Feat: OPENVPN_INTERFACE defaulting to tun0

- Fix: custom config with custom network interface name for firewall
- Keep VPN tunnel interface in firewall state
- Vul fix: only allow traffic through vpn interface when needed
- Adapt code to adapt to network interface name
- Remove outdated TUN and TAP constants
This commit is contained in:
Quentin McGaw (desktop)
2021-08-19 23:22:55 +00:00
parent 7191d4e911
commit bec8ff27ae
20 changed files with 219 additions and 89 deletions

View File

@@ -76,6 +76,7 @@ ENV VPNSP=pia \
OPENVPN_TARGET_IP= \ OPENVPN_TARGET_IP= \
OPENVPN_IPV6=off \ OPENVPN_IPV6=off \
OPENVPN_CUSTOM_CONFIG= \ OPENVPN_CUSTOM_CONFIG= \
OPENVPN_INTERFACE=tun0 \
TZ= \ TZ= \
PUID= \ PUID= \
PGID= \ PGID= \

View File

@@ -290,7 +290,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
} }
for _, vpnPort := range allSettings.Firewall.VPNInputPorts { for _, vpnPort := range allSettings.Firewall.VPNInputPorts {
err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN)) err = firewallConf.SetAllowedPort(ctx, vpnPort, allSettings.VPN.OpenVPN.Interface)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,7 +1,9 @@
package configuration package configuration
import ( import (
"errors"
"fmt" "fmt"
"regexp"
"strconv" "strconv"
"strings" "strings"
@@ -26,6 +28,7 @@ type OpenVPN struct {
EncPreset string `json:"encryption_preset"` // PIA EncPreset string `json:"encryption_preset"` // PIA
IPv6 bool `json:"ipv6"` // Mullvad IPv6 bool `json:"ipv6"` // Mullvad
ProcUser string `json:"procuser"` // Process username ProcUser string `json:"procuser"` // Process username
Interface string `json:"interface"`
} }
func (settings *OpenVPN) String() string { func (settings *OpenVPN) String() string {
@@ -39,6 +42,8 @@ func (settings *OpenVPN) lines() (lines []string) {
lines = append(lines, indent+lastIndent+"Verbosity level: "+strconv.Itoa(settings.Verbosity)) lines = append(lines, indent+lastIndent+"Verbosity level: "+strconv.Itoa(settings.Verbosity))
lines = append(lines, indent+lastIndent+"Network interface: "+settings.Interface)
if len(settings.Flags) > 0 { if len(settings.Flags) > 0 {
lines = append(lines, indent+lastIndent+"Flags: "+strings.Join(settings.Flags, " ")) lines = append(lines, indent+lastIndent+"Flags: "+strings.Join(settings.Flags, " "))
} }
@@ -148,6 +153,11 @@ func (settings *OpenVPN) read(r reader, serviceProvider string) (err error) {
return fmt.Errorf("environment variable OPENVPN_IPV6: %w", err) return fmt.Errorf("environment variable OPENVPN_IPV6: %w", err)
} }
settings.Interface, err = readInterface(r.env)
if err != nil {
return err
}
settings.EncPreset, err = getPIAEncryptionPreset(r) settings.EncPreset, err = getPIAEncryptionPreset(r)
if err != nil { if err != nil {
return err return err
@@ -173,3 +183,22 @@ func readProtocol(env params.Env) (tcp bool, err error) {
} }
return protocol == constants.TCP, nil return protocol == constants.TCP, nil
} }
const openvpnIntfRegexString = `^.*[0-9]$`
var openvpnIntfRegexp = regexp.MustCompile(openvpnIntfRegexString)
var errInterfaceNameNotValid = errors.New("interface name is not valid")
func readInterface(env params.Env) (intf string, err error) {
intf, err = env.Get("OPENVPN_INTERFACE", params.Default("tun0"))
if err != nil {
return "", fmt.Errorf("environment variable OPENVPN_INTERFACE: %w", err)
}
if !openvpnIntfRegexp.MatchString(intf) {
return "", fmt.Errorf("%w: does not match regex %s: %s",
errInterfaceNameNotValid, openvpnIntfRegexString, intf)
}
return intf, nil
}

View File

@@ -29,7 +29,8 @@ func Test_OpenVPN_JSON(t *testing.T) {
"version": "", "version": "",
"encryption_preset": "", "encryption_preset": "",
"ipv6": false, "ipv6": false,
"procuser": "" "procuser": "",
"interface": ""
}`, string(data)) }`, string(data))
var out OpenVPN var out OpenVPN
err = json.Unmarshal(data, &out) err = json.Unmarshal(data, &out)

View File

@@ -22,7 +22,8 @@ func Test_Settings_lines(t *testing.T) {
Name: constants.Mullvad, Name: constants.Mullvad,
}, },
OpenVPN: OpenVPN{ OpenVPN: OpenVPN{
Version: constants.Openvpn25, Version: constants.Openvpn25,
Interface: "tun",
}, },
}, },
}, },
@@ -33,6 +34,7 @@ func Test_Settings_lines(t *testing.T) {
" |--OpenVPN:", " |--OpenVPN:",
" |--Version: 2.5", " |--Version: 2.5",
" |--Verbosity level: 0", " |--Verbosity level: 0",
" |--Network interface: tun",
" |--Mullvad settings:", " |--Mullvad settings:",
" |--OpenVPN selection:", " |--OpenVPN selection:",
" |--Protocol: udp", " |--Protocol: udp",

View File

@@ -1,10 +1,5 @@
package constants package constants
const (
TUN = "tun0"
TAP = "tap0"
)
const ( const (
AES128cbc = "aes-128-cbc" AES128cbc = "aes-128-cbc"
AES256cbc = "aes-256-cbc" AES256cbc = "aes-256-cbc"

View File

@@ -4,8 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/constants"
) )
var ( var (
@@ -109,9 +107,9 @@ func (c *Config) enable(ctx context.Context) (err error) {
if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil { if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return fmt.Errorf("cannot enable firewall: %w", err)
} }
} if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err)
return fmt.Errorf("cannot enable firewall: %w", err) }
} }
for _, network := range c.localNetworks { for _, network := range c.localNetworks {

View File

@@ -40,6 +40,7 @@ type Config struct { //nolint:maligned
// State // State
enabled bool enabled bool
vpnConnection models.Connection vpnConnection models.Connection
vpnIntf string
outboundSubnets []net.IPNet outboundSubnets []net.IPNet
allowedInputPorts map[uint16]string // port to interface mapping allowedInputPorts map[uint16]string // port to interface mapping
stateMutex sync.Mutex stateMutex sync.Mutex

View File

@@ -8,10 +8,12 @@ import (
) )
type VPNConnectionSetter interface { type VPNConnectionSetter interface {
SetVPNConnection(ctx context.Context, connection models.Connection) error SetVPNConnection(ctx context.Context,
connection models.Connection, vpnIntf string) error
} }
func (c *Config) SetVPNConnection(ctx context.Context, connection models.Connection) (err error) { func (c *Config) SetVPNConnection(ctx context.Context,
connection models.Connection, vpnIntf string) (err error) {
c.stateMutex.Lock() c.stateMutex.Lock()
defer c.stateMutex.Unlock() defer c.stateMutex.Unlock()
@@ -34,10 +36,25 @@ func (c *Config) SetVPNConnection(ctx context.Context, connection models.Connect
} }
} }
c.vpnConnection = models.Connection{} c.vpnConnection = models.Connection{}
if c.vpnIntf != "" {
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
c.logger.Error("cannot remove outdated VPN interface from firewall: " + err.Error())
}
}
c.vpnIntf = ""
remove = false remove = false
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil { if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil {
return fmt.Errorf("cannot set VPN connection through firewall: %w", err) return fmt.Errorf("cannot set VPN connection through firewall: %w", err)
} }
c.vpnConnection = connection c.vpnConnection = connection
if err = c.acceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
return fmt.Errorf("cannot accept output traffic through interface %s: %w", vpnIntf, err)
}
c.vpnIntf = vpnIntf
return nil return nil
} }

View File

@@ -33,18 +33,15 @@ func (c Connection) OpenVPNProtoLine() (line string) {
} }
// UpdateEmptyWith updates each field of the connection where the // UpdateEmptyWith updates each field of the connection where the
// value is not set using the value from the other connection. // value is not set using the value given as arguments.
func (c *Connection) UpdateEmptyWith(connection Connection) { func (c *Connection) UpdateEmptyWith(ip net.IP, port uint16, protocol string) {
if c.IP == nil { if c.IP == nil {
c.IP = connection.IP c.IP = ip
} }
if c.Port == 0 { if c.Port == 0 {
c.Port = connection.Port c.Port = port
} }
if c.Protocol == "" { if c.Protocol == "" {
c.Protocol = connection.Protocol c.Protocol = protocol
}
if c.Hostname == "" {
c.Hostname = connection.Hostname
} }
} }

View File

@@ -14,18 +14,22 @@ var (
) )
func BuildConfig(settings configuration.OpenVPN) ( func BuildConfig(settings configuration.OpenVPN) (
lines []string, connection models.Connection, err error) { lines []string, connection models.Connection, intf string, err error) {
lines, err = readCustomConfigLines(settings.Config) lines, err = readCustomConfigLines(settings.Config)
if err != nil { if err != nil {
return nil, connection, fmt.Errorf("%w: %s", ErrReadCustomConfig, err) return nil, connection, "", fmt.Errorf("%w: %s", ErrReadCustomConfig, err)
} }
connection, err = extractConnectionFromLines(lines) connection, intf, err = extractDataFromLines(lines)
if err != nil { if err != nil {
return nil, connection, fmt.Errorf("%w: %s", ErrExtractConnection, err) return nil, connection, "", fmt.Errorf("%w: %s", ErrExtractConnection, err)
} }
lines = modifyCustomConfig(lines, settings, connection) if intf == "" {
intf = settings.Interface
}
return lines, connection, nil lines = modifyCustomConfig(lines, settings, connection, intf)
return lines, connection, intf, nil
} }

View File

@@ -27,18 +27,20 @@ func Test_BuildConfig(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
settings := configuration.OpenVPN{ settings := configuration.OpenVPN{
Cipher: "cipher", Cipher: "cipher",
MSSFix: 999, MSSFix: 999,
Config: file.Name(), Config: file.Name(),
Interface: "tun",
} }
lines, connection, err := BuildConfig(settings) lines, connection, intf, err := BuildConfig(settings)
assert.NoError(t, err) assert.NoError(t, err)
expectedLines := []string{ expectedLines := []string{
"keep me", "keep me",
"proto udp", "proto udp",
"remote 1.9.8.7 1194", "remote 1.9.8.7 1194",
"dev tun",
"mute-replay-warnings", "mute-replay-warnings",
"auth-nocache", "auth-nocache",
"pull-filter ignore \"auth-token\"", "pull-filter ignore \"auth-token\"",
@@ -60,4 +62,6 @@ func Test_BuildConfig(t *testing.T) {
Protocol: constants.UDP, Protocol: constants.UDP,
} }
assert.Equal(t, expectedConnection, connection) assert.Equal(t, expectedConnection, connection)
assert.Equal(t, "tun", intf)
} }

View File

@@ -15,23 +15,24 @@ var (
errRemoteLineNotFound = errors.New("remote line not found") errRemoteLineNotFound = errors.New("remote line not found")
) )
// extractConnectionFromLines always takes the first remote line only. func extractDataFromLines(lines []string) (
func extractConnectionFromLines(lines []string) ( connection models.Connection, intf string, err error) {
connection models.Connection, err error) {
for i, line := range lines { for i, line := range lines {
newConnectionData, err := extractConnectionFromLine(line) ip, port, protocol, intfFound, err := extractDataFromLine(line)
if err != nil { if err != nil {
return connection, fmt.Errorf("on line %d: %w", i+1, err) return connection, "", fmt.Errorf("on line %d: %w", i+1, err)
} }
connection.UpdateEmptyWith(newConnectionData)
if connection.Protocol != "" && connection.IP != nil { intf = intfFound
connection.UpdateEmptyWith(ip, port, protocol)
if connection.Protocol != "" && connection.IP != nil && intf != "" {
break break
} }
} }
if connection.IP == nil { if connection.IP == nil {
return connection, errRemoteLineNotFound return connection, "", errRemoteLineNotFound
} }
if connection.Protocol == "" { if connection.Protocol == "" {
@@ -45,32 +46,41 @@ func extractConnectionFromLines(lines []string) (
} }
} }
return connection, nil return connection, intf, nil
} }
var ( var (
errExtractProto = errors.New("failed extracting protocol from proto line") errExtractProto = errors.New("failed extracting protocol from proto line")
errExtractRemote = errors.New("failed extracting protocol from remote line") errExtractRemote = errors.New("failed extracting from remote line")
errExtractDev = errors.New("failed extracting network interface from dev line")
) )
func extractConnectionFromLine(line string) ( func extractDataFromLine(line string) (
connection models.Connection, err error) { ip net.IP, port uint16, protocol, intf string, err error) {
switch { switch {
case strings.HasPrefix(line, "proto "): case strings.HasPrefix(line, "proto "):
connection.Protocol, err = extractProto(line) protocol, err = extractProto(line)
if err != nil { if err != nil {
return connection, fmt.Errorf("%w: %s", errExtractProto, err) return nil, 0, "", "", fmt.Errorf("%w: %s", errExtractProto, err)
} }
return nil, 0, protocol, "", nil
// only take the first remote line case strings.HasPrefix(line, "remote "):
case strings.HasPrefix(line, "remote ") && connection.IP == nil: ip, port, protocol, err = extractRemote(line)
connection.IP, connection.Port, connection.Protocol, err = extractRemote(line)
if err != nil { if err != nil {
return connection, fmt.Errorf("%w: %s", errExtractRemote, err) return nil, 0, "", "", fmt.Errorf("%w: %s", errExtractRemote, err)
} }
return ip, port, protocol, "", nil
case strings.HasPrefix(line, "dev "):
intf, err = extractInterfaceFromLine(line)
if err != nil {
return nil, 0, "", "", fmt.Errorf("%w: %s", errExtractDev, err)
}
return nil, 0, "", intf, nil
} }
return connection, nil return nil, 0, "", "", nil
} }
var ( var (
@@ -137,3 +147,16 @@ func extractRemote(line string) (ip net.IP, port uint16,
return ip, port, protocol, nil return ip, port, protocol, nil
} }
var (
errDevLineFieldsCount = errors.New("dev line has not 2 fields as expected")
)
func extractInterfaceFromLine(line string) (intf string, err error) {
fields := strings.Fields(line)
if len(fields) != 2 { //nolint:gomnd
return "", fmt.Errorf("%w: %s", errDevLineFieldsCount, line)
}
return fields[1], nil
}

View File

@@ -11,21 +11,23 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func Test_extractConnectionFromLines(t *testing.T) { func Test_extractDataFromLines(t *testing.T) {
t.Parallel() t.Parallel()
testCases := map[string]struct { testCases := map[string]struct {
lines []string lines []string
connection models.Connection connection models.Connection
intf string
err error err error
}{ }{
"success": { "success": {
lines: []string{"bla bla", "proto tcp", "remote 1.2.3.4 1194 tcp"}, lines: []string{"bla bla", "proto tcp", "remote 1.2.3.4 1194 tcp", "dev tun6"},
connection: models.Connection{ connection: models.Connection{
IP: net.IPv4(1, 2, 3, 4), IP: net.IPv4(1, 2, 3, 4),
Port: 1194, Port: 1194,
Protocol: constants.TCP, Protocol: constants.TCP,
}, },
intf: "tun6",
}, },
"extraction error": { "extraction error": {
lines: []string{"bla bla", "proto bad", "remote 1.2.3.4 1194 tcp"}, lines: []string{"bla bla", "proto bad", "remote 1.2.3.4 1194 tcp"},
@@ -69,7 +71,7 @@ func Test_extractConnectionFromLines(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
connection, err := extractConnectionFromLines(testCase.lines) connection, intf, err := extractDataFromLines(testCase.lines)
if testCase.err != nil { if testCase.err != nil {
require.Error(t, err) require.Error(t, err)
@@ -79,17 +81,21 @@ func Test_extractConnectionFromLines(t *testing.T) {
} }
assert.Equal(t, testCase.connection, connection) assert.Equal(t, testCase.connection, connection)
assert.Equal(t, testCase.intf, intf)
}) })
} }
} }
func Test_extractConnectionFromLine(t *testing.T) { func Test_extractDataFromLine(t *testing.T) {
t.Parallel() t.Parallel()
testCases := map[string]struct { testCases := map[string]struct {
line string line string
connection models.Connection ip net.IP
isErr error port uint16
protocol string
intf string
isErr error
}{ }{
"irrelevant line": { "irrelevant line": {
line: "bla bla", line: "bla bla",
@@ -99,22 +105,26 @@ func Test_extractConnectionFromLine(t *testing.T) {
isErr: errExtractProto, isErr: errExtractProto,
}, },
"extract proto success": { "extract proto success": {
line: "proto tcp", line: "proto tcp",
connection: models.Connection{ protocol: constants.TCP,
Protocol: constants.TCP, },
}, "extract intf error": {
line: "dev ",
isErr: errExtractDev,
},
"extract intf success": {
line: "dev tun3",
intf: "tun3",
}, },
"extract remote error": { "extract remote error": {
line: "remote bad", line: "remote bad",
isErr: errExtractRemote, isErr: errExtractRemote,
}, },
"extract remote success": { "extract remote success": {
line: "remote 1.2.3.4 1194 udp", line: "remote 1.2.3.4 1194 udp",
connection: models.Connection{ ip: net.IPv4(1, 2, 3, 4),
IP: net.IPv4(1, 2, 3, 4), port: 1194,
Port: 1194, protocol: constants.UDP,
Protocol: constants.UDP,
},
}, },
} }
@@ -123,7 +133,7 @@ func Test_extractConnectionFromLine(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
connection, err := extractConnectionFromLine(testCase.line) ip, port, protocol, intf, err := extractDataFromLine(testCase.line)
if testCase.isErr != nil { if testCase.isErr != nil {
assert.ErrorIs(t, err, testCase.isErr) assert.ErrorIs(t, err, testCase.isErr)
@@ -131,7 +141,10 @@ func Test_extractConnectionFromLine(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
assert.Equal(t, testCase.connection, connection) assert.Equal(t, testCase.ip, ip)
assert.Equal(t, testCase.port, port)
assert.Equal(t, testCase.protocol, protocol)
assert.Equal(t, testCase.intf, intf)
}) })
} }
} }
@@ -260,3 +273,44 @@ func Test_extractRemote(t *testing.T) {
}) })
} }
} }
func Test_extractInterface(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
line string
intf string
err error
}{
"found": {
line: "dev tun3",
intf: "tun3",
},
"not enough fields": {
line: "dev ",
err: errors.New("dev line has not 2 fields as expected: dev "),
},
"too many fields": {
line: "dev one two",
err: errors.New("dev line has not 2 fields as expected: dev one two"),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
intf, err := extractInterfaceFromLine(testCase.line)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.intf, intf)
})
}
}

View File

@@ -11,7 +11,7 @@ import (
) )
func modifyCustomConfig(lines []string, settings configuration.OpenVPN, func modifyCustomConfig(lines []string, settings configuration.OpenVPN,
connection models.Connection) (modified []string) { connection models.Connection, intf string) (modified []string) {
// Remove some lines // Remove some lines
for _, line := range lines { for _, line := range lines {
switch { switch {
@@ -22,6 +22,7 @@ func modifyCustomConfig(lines []string, settings configuration.OpenVPN,
strings.HasPrefix(line, "user "), strings.HasPrefix(line, "user "),
strings.HasPrefix(line, "proto "), strings.HasPrefix(line, "proto "),
strings.HasPrefix(line, "remote "), strings.HasPrefix(line, "remote "),
strings.HasPrefix(line, "dev "),
settings.Cipher != "" && strings.HasPrefix(line, "cipher "), settings.Cipher != "" && strings.HasPrefix(line, "cipher "),
settings.Cipher != "" && strings.HasPrefix(line, "data-ciphers "), settings.Cipher != "" && strings.HasPrefix(line, "data-ciphers "),
settings.Auth != "" && strings.HasPrefix(line, "auth "), settings.Auth != "" && strings.HasPrefix(line, "auth "),
@@ -35,6 +36,7 @@ func modifyCustomConfig(lines []string, settings configuration.OpenVPN,
// Add values // Add values
modified = append(modified, connection.OpenVPNProtoLine()) modified = append(modified, connection.OpenVPNProtoLine())
modified = append(modified, connection.OpenVPNRemoteLine()) modified = append(modified, connection.OpenVPNRemoteLine())
modified = append(modified, "dev "+intf)
modified = append(modified, "mute-replay-warnings") modified = append(modified, "mute-replay-warnings")
modified = append(modified, "auth-nocache") modified = append(modified, "auth-nocache")
modified = append(modified, "pull-filter ignore \"auth-token\"") // prevent auth failed loop modified = append(modified, "pull-filter ignore \"auth-token\"") // prevent auth failed loop

View File

@@ -17,6 +17,7 @@ func Test_modifyCustomConfig(t *testing.T) {
lines []string lines []string
settings configuration.OpenVPN settings configuration.OpenVPN
connection models.Connection connection models.Connection
intf string
modified []string modified []string
}{ }{
"mixed": { "mixed": {
@@ -41,10 +42,12 @@ func Test_modifyCustomConfig(t *testing.T) {
Port: 1194, Port: 1194,
Protocol: constants.UDP, Protocol: constants.UDP,
}, },
intf: "tun3",
modified: []string{ modified: []string{
"keep me here", "keep me here",
"proto udp", "proto udp",
"remote 1.2.3.4 1194", "remote 1.2.3.4 1194",
"dev tun3",
"mute-replay-warnings", "mute-replay-warnings",
"auth-nocache", "auth-nocache",
"pull-filter ignore \"auth-token\"", "pull-filter ignore \"auth-token\"",
@@ -69,7 +72,7 @@ func Test_modifyCustomConfig(t *testing.T) {
t.Parallel() t.Parallel()
modified := modifyCustomConfig(testCase.lines, modified := modifyCustomConfig(testCase.lines,
testCase.settings, testCase.connection) testCase.settings, testCase.connection, testCase.intf)
assert.Equal(t, testCase.modified, modified) assert.Equal(t, testCase.modified, modified)
}) })

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"net" "net"
"github.com/qdm12/gluetun/internal/constants"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
) )
@@ -242,10 +241,10 @@ func (r *routing) VPNDestinationIP() (ip net.IP, err error) {
} }
type VPNLocalGatewayIPGetter interface { type VPNLocalGatewayIPGetter interface {
VPNLocalGatewayIP() (ip net.IP, err error) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error)
} }
func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) { func (r *routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) return nil, fmt.Errorf("%w: %s", ErrRoutesList, err)
@@ -256,7 +255,7 @@ func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) {
return nil, fmt.Errorf("%w: %s", ErrLinkByIndex, err) return nil, fmt.Errorf("%w: %s", ErrLinkByIndex, err)
} }
interfaceName := link.Attrs().Name interfaceName := link.Attrs().Name
if interfaceName == string(constants.TUN) && if interfaceName == vpnIntf &&
route.Dst != nil && route.Dst != nil &&
route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) { route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) {
return route.Gw, nil return route.Gw, nil

View File

@@ -29,14 +29,16 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
settings configuration.VPN, starter command.Starter, logger logging.Logger) ( settings configuration.VPN, starter command.Starter, logger logging.Logger) (
runner vpnRunner, serverName string, err error) { runner vpnRunner, serverName string, err error) {
var connection models.Connection var connection models.Connection
var netInterface string
var lines []string var lines []string
if settings.OpenVPN.Config == "" { if settings.OpenVPN.Config == "" {
netInterface = settings.OpenVPN.Interface
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection) connection, err = providerConf.GetConnection(settings.Provider.ServerSelection)
if err == nil { if err == nil {
lines = providerConf.BuildConf(connection, settings.OpenVPN) lines = providerConf.BuildConf(connection, settings.OpenVPN)
} }
} else { } else {
lines, connection, err = custom.BuildConfig(settings.OpenVPN) lines, connection, netInterface, err = custom.BuildConfig(settings.OpenVPN)
} }
if err != nil { if err != nil {
return nil, "", fmt.Errorf("%w: %s", errBuildConfig, err) return nil, "", fmt.Errorf("%w: %s", errBuildConfig, err)
@@ -53,7 +55,7 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
} }
} }
if err := fw.SetVPNConnection(ctx, connection); err != nil { if err := fw.SetVPNConnection(ctx, connection, netInterface); err != nil {
return nil, "", fmt.Errorf("%w: %s", errFirewall, err) return nil, "", fmt.Errorf("%w: %s", errFirewall, err)
} }

View File

@@ -6,9 +6,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/portforward" "github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/provider"
) )
var ( var (
@@ -16,24 +14,23 @@ var (
errStartPortForwarding = errors.New("cannot start port forwarding") errStartPortForwarding = errors.New("cannot start port forwarding")
) )
func (l *Loop) startPortForwarding(ctx context.Context, enabled bool, func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) {
portForwarder provider.PortForwarder, serverName string) (err error) { if !data.portForwarding {
if !enabled {
return nil return nil
} }
// only used for PIA for now // only used for PIA for now
gateway, err := l.routing.VPNLocalGatewayIP() gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", errObtainVPNLocalGateway, err) return fmt.Errorf("%w: %s", errObtainVPNLocalGateway, err)
} }
l.logger.Info("VPN gateway IP address: " + gateway.String()) l.logger.Info("VPN gateway IP address: " + gateway.String())
pfData := portforward.StartData{ pfData := portforward.StartData{
PortForwarder: portForwarder, PortForwarder: data.portForwarder,
Gateway: gateway, Gateway: gateway,
ServerName: serverName, ServerName: data.serverName,
Interface: constants.TUN, Interface: data.vpnIntf,
} }
_, err = l.portForward.Start(ctx, pfData) _, err = l.portForward.Start(ctx, pfData)
if err != nil { if err != nil {

View File

@@ -11,6 +11,7 @@ import (
type tunnelUpData struct { type tunnelUpData struct {
// Port forwarding // Port forwarding
portForwarding bool portForwarding bool
vpnIntf string
serverName string serverName string
portForwarder provider.PortForwarder portForwarder provider.PortForwarder
} }
@@ -39,7 +40,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
} }
} }
err = l.startPortForwarding(ctx, data.portForwarding, data.portForwarder, data.serverName) err = l.startPortForwarding(ctx, data)
if err != nil { if err != nil {
l.logger.Error(err.Error()) l.logger.Error(err.Error())
} }