Feat: OPENVPN_INTERFACE defaulting to tun0
- Fix: custom config with custom network interface name for firewall - Keep VPN tunnel interface in firewall state - Vul fix: only allow traffic through vpn interface when needed - Adapt code to adapt to network interface name - Remove outdated TUN and TAP constants
This commit is contained in:
@@ -76,6 +76,7 @@ ENV VPNSP=pia \
|
|||||||
OPENVPN_TARGET_IP= \
|
OPENVPN_TARGET_IP= \
|
||||||
OPENVPN_IPV6=off \
|
OPENVPN_IPV6=off \
|
||||||
OPENVPN_CUSTOM_CONFIG= \
|
OPENVPN_CUSTOM_CONFIG= \
|
||||||
|
OPENVPN_INTERFACE=tun0 \
|
||||||
TZ= \
|
TZ= \
|
||||||
PUID= \
|
PUID= \
|
||||||
PGID= \
|
PGID= \
|
||||||
|
|||||||
@@ -290,7 +290,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, vpnPort := range allSettings.Firewall.VPNInputPorts {
|
for _, vpnPort := range allSettings.Firewall.VPNInputPorts {
|
||||||
err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN))
|
err = firewallConf.SetAllowedPort(ctx, vpnPort, allSettings.VPN.OpenVPN.Interface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package configuration
|
package configuration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -26,6 +28,7 @@ type OpenVPN struct {
|
|||||||
EncPreset string `json:"encryption_preset"` // PIA
|
EncPreset string `json:"encryption_preset"` // PIA
|
||||||
IPv6 bool `json:"ipv6"` // Mullvad
|
IPv6 bool `json:"ipv6"` // Mullvad
|
||||||
ProcUser string `json:"procuser"` // Process username
|
ProcUser string `json:"procuser"` // Process username
|
||||||
|
Interface string `json:"interface"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (settings *OpenVPN) String() string {
|
func (settings *OpenVPN) String() string {
|
||||||
@@ -39,6 +42,8 @@ func (settings *OpenVPN) lines() (lines []string) {
|
|||||||
|
|
||||||
lines = append(lines, indent+lastIndent+"Verbosity level: "+strconv.Itoa(settings.Verbosity))
|
lines = append(lines, indent+lastIndent+"Verbosity level: "+strconv.Itoa(settings.Verbosity))
|
||||||
|
|
||||||
|
lines = append(lines, indent+lastIndent+"Network interface: "+settings.Interface)
|
||||||
|
|
||||||
if len(settings.Flags) > 0 {
|
if len(settings.Flags) > 0 {
|
||||||
lines = append(lines, indent+lastIndent+"Flags: "+strings.Join(settings.Flags, " "))
|
lines = append(lines, indent+lastIndent+"Flags: "+strings.Join(settings.Flags, " "))
|
||||||
}
|
}
|
||||||
@@ -148,6 +153,11 @@ func (settings *OpenVPN) read(r reader, serviceProvider string) (err error) {
|
|||||||
return fmt.Errorf("environment variable OPENVPN_IPV6: %w", err)
|
return fmt.Errorf("environment variable OPENVPN_IPV6: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
settings.Interface, err = readInterface(r.env)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
settings.EncPreset, err = getPIAEncryptionPreset(r)
|
settings.EncPreset, err = getPIAEncryptionPreset(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -173,3 +183,22 @@ func readProtocol(env params.Env) (tcp bool, err error) {
|
|||||||
}
|
}
|
||||||
return protocol == constants.TCP, nil
|
return protocol == constants.TCP, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const openvpnIntfRegexString = `^.*[0-9]$`
|
||||||
|
|
||||||
|
var openvpnIntfRegexp = regexp.MustCompile(openvpnIntfRegexString)
|
||||||
|
var errInterfaceNameNotValid = errors.New("interface name is not valid")
|
||||||
|
|
||||||
|
func readInterface(env params.Env) (intf string, err error) {
|
||||||
|
intf, err = env.Get("OPENVPN_INTERFACE", params.Default("tun0"))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("environment variable OPENVPN_INTERFACE: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !openvpnIntfRegexp.MatchString(intf) {
|
||||||
|
return "", fmt.Errorf("%w: does not match regex %s: %s",
|
||||||
|
errInterfaceNameNotValid, openvpnIntfRegexString, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
return intf, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -29,7 +29,8 @@ func Test_OpenVPN_JSON(t *testing.T) {
|
|||||||
"version": "",
|
"version": "",
|
||||||
"encryption_preset": "",
|
"encryption_preset": "",
|
||||||
"ipv6": false,
|
"ipv6": false,
|
||||||
"procuser": ""
|
"procuser": "",
|
||||||
|
"interface": ""
|
||||||
}`, string(data))
|
}`, string(data))
|
||||||
var out OpenVPN
|
var out OpenVPN
|
||||||
err = json.Unmarshal(data, &out)
|
err = json.Unmarshal(data, &out)
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ func Test_Settings_lines(t *testing.T) {
|
|||||||
Name: constants.Mullvad,
|
Name: constants.Mullvad,
|
||||||
},
|
},
|
||||||
OpenVPN: OpenVPN{
|
OpenVPN: OpenVPN{
|
||||||
Version: constants.Openvpn25,
|
Version: constants.Openvpn25,
|
||||||
|
Interface: "tun",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -33,6 +34,7 @@ func Test_Settings_lines(t *testing.T) {
|
|||||||
" |--OpenVPN:",
|
" |--OpenVPN:",
|
||||||
" |--Version: 2.5",
|
" |--Version: 2.5",
|
||||||
" |--Verbosity level: 0",
|
" |--Verbosity level: 0",
|
||||||
|
" |--Network interface: tun",
|
||||||
" |--Mullvad settings:",
|
" |--Mullvad settings:",
|
||||||
" |--OpenVPN selection:",
|
" |--OpenVPN selection:",
|
||||||
" |--Protocol: udp",
|
" |--Protocol: udp",
|
||||||
|
|||||||
@@ -1,10 +1,5 @@
|
|||||||
package constants
|
package constants
|
||||||
|
|
||||||
const (
|
|
||||||
TUN = "tun0"
|
|
||||||
TAP = "tap0"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
AES128cbc = "aes-128-cbc"
|
AES128cbc = "aes-128-cbc"
|
||||||
AES256cbc = "aes-256-cbc"
|
AES256cbc = "aes-256-cbc"
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -109,9 +107,9 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
|||||||
if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
|
if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
}
|
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
||||||
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range c.localNetworks {
|
for _, network := range c.localNetworks {
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ type Config struct { //nolint:maligned
|
|||||||
// State
|
// State
|
||||||
enabled bool
|
enabled bool
|
||||||
vpnConnection models.Connection
|
vpnConnection models.Connection
|
||||||
|
vpnIntf string
|
||||||
outboundSubnets []net.IPNet
|
outboundSubnets []net.IPNet
|
||||||
allowedInputPorts map[uint16]string // port to interface mapping
|
allowedInputPorts map[uint16]string // port to interface mapping
|
||||||
stateMutex sync.Mutex
|
stateMutex sync.Mutex
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type VPNConnectionSetter interface {
|
type VPNConnectionSetter interface {
|
||||||
SetVPNConnection(ctx context.Context, connection models.Connection) error
|
SetVPNConnection(ctx context.Context,
|
||||||
|
connection models.Connection, vpnIntf string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) SetVPNConnection(ctx context.Context, connection models.Connection) (err error) {
|
func (c *Config) SetVPNConnection(ctx context.Context,
|
||||||
|
connection models.Connection, vpnIntf string) (err error) {
|
||||||
c.stateMutex.Lock()
|
c.stateMutex.Lock()
|
||||||
defer c.stateMutex.Unlock()
|
defer c.stateMutex.Unlock()
|
||||||
|
|
||||||
@@ -34,10 +36,25 @@ func (c *Config) SetVPNConnection(ctx context.Context, connection models.Connect
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.vpnConnection = models.Connection{}
|
c.vpnConnection = models.Connection{}
|
||||||
|
|
||||||
|
if c.vpnIntf != "" {
|
||||||
|
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
||||||
|
c.logger.Error("cannot remove outdated VPN interface from firewall: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.vpnIntf = ""
|
||||||
|
|
||||||
remove = false
|
remove = false
|
||||||
|
|
||||||
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil {
|
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil {
|
||||||
return fmt.Errorf("cannot set VPN connection through firewall: %w", err)
|
return fmt.Errorf("cannot set VPN connection through firewall: %w", err)
|
||||||
}
|
}
|
||||||
c.vpnConnection = connection
|
c.vpnConnection = connection
|
||||||
|
|
||||||
|
if err = c.acceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot accept output traffic through interface %s: %w", vpnIntf, err)
|
||||||
|
}
|
||||||
|
c.vpnIntf = vpnIntf
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,18 +33,15 @@ func (c Connection) OpenVPNProtoLine() (line string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateEmptyWith updates each field of the connection where the
|
// UpdateEmptyWith updates each field of the connection where the
|
||||||
// value is not set using the value from the other connection.
|
// value is not set using the value given as arguments.
|
||||||
func (c *Connection) UpdateEmptyWith(connection Connection) {
|
func (c *Connection) UpdateEmptyWith(ip net.IP, port uint16, protocol string) {
|
||||||
if c.IP == nil {
|
if c.IP == nil {
|
||||||
c.IP = connection.IP
|
c.IP = ip
|
||||||
}
|
}
|
||||||
if c.Port == 0 {
|
if c.Port == 0 {
|
||||||
c.Port = connection.Port
|
c.Port = port
|
||||||
}
|
}
|
||||||
if c.Protocol == "" {
|
if c.Protocol == "" {
|
||||||
c.Protocol = connection.Protocol
|
c.Protocol = protocol
|
||||||
}
|
|
||||||
if c.Hostname == "" {
|
|
||||||
c.Hostname = connection.Hostname
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,18 +14,22 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func BuildConfig(settings configuration.OpenVPN) (
|
func BuildConfig(settings configuration.OpenVPN) (
|
||||||
lines []string, connection models.Connection, err error) {
|
lines []string, connection models.Connection, intf string, err error) {
|
||||||
lines, err = readCustomConfigLines(settings.Config)
|
lines, err = readCustomConfigLines(settings.Config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, connection, fmt.Errorf("%w: %s", ErrReadCustomConfig, err)
|
return nil, connection, "", fmt.Errorf("%w: %s", ErrReadCustomConfig, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
connection, err = extractConnectionFromLines(lines)
|
connection, intf, err = extractDataFromLines(lines)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, connection, fmt.Errorf("%w: %s", ErrExtractConnection, err)
|
return nil, connection, "", fmt.Errorf("%w: %s", ErrExtractConnection, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
lines = modifyCustomConfig(lines, settings, connection)
|
if intf == "" {
|
||||||
|
intf = settings.Interface
|
||||||
|
}
|
||||||
|
|
||||||
return lines, connection, nil
|
lines = modifyCustomConfig(lines, settings, connection, intf)
|
||||||
|
|
||||||
|
return lines, connection, intf, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,18 +27,20 @@ func Test_BuildConfig(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
settings := configuration.OpenVPN{
|
settings := configuration.OpenVPN{
|
||||||
Cipher: "cipher",
|
Cipher: "cipher",
|
||||||
MSSFix: 999,
|
MSSFix: 999,
|
||||||
Config: file.Name(),
|
Config: file.Name(),
|
||||||
|
Interface: "tun",
|
||||||
}
|
}
|
||||||
|
|
||||||
lines, connection, err := BuildConfig(settings)
|
lines, connection, intf, err := BuildConfig(settings)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
expectedLines := []string{
|
expectedLines := []string{
|
||||||
"keep me",
|
"keep me",
|
||||||
"proto udp",
|
"proto udp",
|
||||||
"remote 1.9.8.7 1194",
|
"remote 1.9.8.7 1194",
|
||||||
|
"dev tun",
|
||||||
"mute-replay-warnings",
|
"mute-replay-warnings",
|
||||||
"auth-nocache",
|
"auth-nocache",
|
||||||
"pull-filter ignore \"auth-token\"",
|
"pull-filter ignore \"auth-token\"",
|
||||||
@@ -60,4 +62,6 @@ func Test_BuildConfig(t *testing.T) {
|
|||||||
Protocol: constants.UDP,
|
Protocol: constants.UDP,
|
||||||
}
|
}
|
||||||
assert.Equal(t, expectedConnection, connection)
|
assert.Equal(t, expectedConnection, connection)
|
||||||
|
|
||||||
|
assert.Equal(t, "tun", intf)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,23 +15,24 @@ var (
|
|||||||
errRemoteLineNotFound = errors.New("remote line not found")
|
errRemoteLineNotFound = errors.New("remote line not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
// extractConnectionFromLines always takes the first remote line only.
|
func extractDataFromLines(lines []string) (
|
||||||
func extractConnectionFromLines(lines []string) (
|
connection models.Connection, intf string, err error) {
|
||||||
connection models.Connection, err error) {
|
|
||||||
for i, line := range lines {
|
for i, line := range lines {
|
||||||
newConnectionData, err := extractConnectionFromLine(line)
|
ip, port, protocol, intfFound, err := extractDataFromLine(line)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return connection, fmt.Errorf("on line %d: %w", i+1, err)
|
return connection, "", fmt.Errorf("on line %d: %w", i+1, err)
|
||||||
}
|
}
|
||||||
connection.UpdateEmptyWith(newConnectionData)
|
|
||||||
|
|
||||||
if connection.Protocol != "" && connection.IP != nil {
|
intf = intfFound
|
||||||
|
connection.UpdateEmptyWith(ip, port, protocol)
|
||||||
|
|
||||||
|
if connection.Protocol != "" && connection.IP != nil && intf != "" {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if connection.IP == nil {
|
if connection.IP == nil {
|
||||||
return connection, errRemoteLineNotFound
|
return connection, "", errRemoteLineNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
if connection.Protocol == "" {
|
if connection.Protocol == "" {
|
||||||
@@ -45,32 +46,41 @@ func extractConnectionFromLines(lines []string) (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return connection, nil
|
return connection, intf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errExtractProto = errors.New("failed extracting protocol from proto line")
|
errExtractProto = errors.New("failed extracting protocol from proto line")
|
||||||
errExtractRemote = errors.New("failed extracting protocol from remote line")
|
errExtractRemote = errors.New("failed extracting from remote line")
|
||||||
|
errExtractDev = errors.New("failed extracting network interface from dev line")
|
||||||
)
|
)
|
||||||
|
|
||||||
func extractConnectionFromLine(line string) (
|
func extractDataFromLine(line string) (
|
||||||
connection models.Connection, err error) {
|
ip net.IP, port uint16, protocol, intf string, err error) {
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(line, "proto "):
|
case strings.HasPrefix(line, "proto "):
|
||||||
connection.Protocol, err = extractProto(line)
|
protocol, err = extractProto(line)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return connection, fmt.Errorf("%w: %s", errExtractProto, err)
|
return nil, 0, "", "", fmt.Errorf("%w: %s", errExtractProto, err)
|
||||||
}
|
}
|
||||||
|
return nil, 0, protocol, "", nil
|
||||||
|
|
||||||
// only take the first remote line
|
case strings.HasPrefix(line, "remote "):
|
||||||
case strings.HasPrefix(line, "remote ") && connection.IP == nil:
|
ip, port, protocol, err = extractRemote(line)
|
||||||
connection.IP, connection.Port, connection.Protocol, err = extractRemote(line)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return connection, fmt.Errorf("%w: %s", errExtractRemote, err)
|
return nil, 0, "", "", fmt.Errorf("%w: %s", errExtractRemote, err)
|
||||||
}
|
}
|
||||||
|
return ip, port, protocol, "", nil
|
||||||
|
|
||||||
|
case strings.HasPrefix(line, "dev "):
|
||||||
|
intf, err = extractInterfaceFromLine(line)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, "", "", fmt.Errorf("%w: %s", errExtractDev, err)
|
||||||
|
}
|
||||||
|
return nil, 0, "", intf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return connection, nil
|
return nil, 0, "", "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -137,3 +147,16 @@ func extractRemote(line string) (ip net.IP, port uint16,
|
|||||||
|
|
||||||
return ip, port, protocol, nil
|
return ip, port, protocol, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errDevLineFieldsCount = errors.New("dev line has not 2 fields as expected")
|
||||||
|
)
|
||||||
|
|
||||||
|
func extractInterfaceFromLine(line string) (intf string, err error) {
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) != 2 { //nolint:gomnd
|
||||||
|
return "", fmt.Errorf("%w: %s", errDevLineFieldsCount, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fields[1], nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,21 +11,23 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_extractConnectionFromLines(t *testing.T) {
|
func Test_extractDataFromLines(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
lines []string
|
lines []string
|
||||||
connection models.Connection
|
connection models.Connection
|
||||||
|
intf string
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
"success": {
|
"success": {
|
||||||
lines: []string{"bla bla", "proto tcp", "remote 1.2.3.4 1194 tcp"},
|
lines: []string{"bla bla", "proto tcp", "remote 1.2.3.4 1194 tcp", "dev tun6"},
|
||||||
connection: models.Connection{
|
connection: models.Connection{
|
||||||
IP: net.IPv4(1, 2, 3, 4),
|
IP: net.IPv4(1, 2, 3, 4),
|
||||||
Port: 1194,
|
Port: 1194,
|
||||||
Protocol: constants.TCP,
|
Protocol: constants.TCP,
|
||||||
},
|
},
|
||||||
|
intf: "tun6",
|
||||||
},
|
},
|
||||||
"extraction error": {
|
"extraction error": {
|
||||||
lines: []string{"bla bla", "proto bad", "remote 1.2.3.4 1194 tcp"},
|
lines: []string{"bla bla", "proto bad", "remote 1.2.3.4 1194 tcp"},
|
||||||
@@ -69,7 +71,7 @@ func Test_extractConnectionFromLines(t *testing.T) {
|
|||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
connection, err := extractConnectionFromLines(testCase.lines)
|
connection, intf, err := extractDataFromLines(testCase.lines)
|
||||||
|
|
||||||
if testCase.err != nil {
|
if testCase.err != nil {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -79,17 +81,21 @@ func Test_extractConnectionFromLines(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, testCase.connection, connection)
|
assert.Equal(t, testCase.connection, connection)
|
||||||
|
assert.Equal(t, testCase.intf, intf)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_extractConnectionFromLine(t *testing.T) {
|
func Test_extractDataFromLine(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
line string
|
line string
|
||||||
connection models.Connection
|
ip net.IP
|
||||||
isErr error
|
port uint16
|
||||||
|
protocol string
|
||||||
|
intf string
|
||||||
|
isErr error
|
||||||
}{
|
}{
|
||||||
"irrelevant line": {
|
"irrelevant line": {
|
||||||
line: "bla bla",
|
line: "bla bla",
|
||||||
@@ -99,22 +105,26 @@ func Test_extractConnectionFromLine(t *testing.T) {
|
|||||||
isErr: errExtractProto,
|
isErr: errExtractProto,
|
||||||
},
|
},
|
||||||
"extract proto success": {
|
"extract proto success": {
|
||||||
line: "proto tcp",
|
line: "proto tcp",
|
||||||
connection: models.Connection{
|
protocol: constants.TCP,
|
||||||
Protocol: constants.TCP,
|
},
|
||||||
},
|
"extract intf error": {
|
||||||
|
line: "dev ",
|
||||||
|
isErr: errExtractDev,
|
||||||
|
},
|
||||||
|
"extract intf success": {
|
||||||
|
line: "dev tun3",
|
||||||
|
intf: "tun3",
|
||||||
},
|
},
|
||||||
"extract remote error": {
|
"extract remote error": {
|
||||||
line: "remote bad",
|
line: "remote bad",
|
||||||
isErr: errExtractRemote,
|
isErr: errExtractRemote,
|
||||||
},
|
},
|
||||||
"extract remote success": {
|
"extract remote success": {
|
||||||
line: "remote 1.2.3.4 1194 udp",
|
line: "remote 1.2.3.4 1194 udp",
|
||||||
connection: models.Connection{
|
ip: net.IPv4(1, 2, 3, 4),
|
||||||
IP: net.IPv4(1, 2, 3, 4),
|
port: 1194,
|
||||||
Port: 1194,
|
protocol: constants.UDP,
|
||||||
Protocol: constants.UDP,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,7 +133,7 @@ func Test_extractConnectionFromLine(t *testing.T) {
|
|||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
connection, err := extractConnectionFromLine(testCase.line)
|
ip, port, protocol, intf, err := extractDataFromLine(testCase.line)
|
||||||
|
|
||||||
if testCase.isErr != nil {
|
if testCase.isErr != nil {
|
||||||
assert.ErrorIs(t, err, testCase.isErr)
|
assert.ErrorIs(t, err, testCase.isErr)
|
||||||
@@ -131,7 +141,10 @@ func Test_extractConnectionFromLine(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, testCase.connection, connection)
|
assert.Equal(t, testCase.ip, ip)
|
||||||
|
assert.Equal(t, testCase.port, port)
|
||||||
|
assert.Equal(t, testCase.protocol, protocol)
|
||||||
|
assert.Equal(t, testCase.intf, intf)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -260,3 +273,44 @@ func Test_extractRemote(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_extractInterface(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
line string
|
||||||
|
intf string
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
"found": {
|
||||||
|
line: "dev tun3",
|
||||||
|
intf: "tun3",
|
||||||
|
},
|
||||||
|
"not enough fields": {
|
||||||
|
line: "dev ",
|
||||||
|
err: errors.New("dev line has not 2 fields as expected: dev "),
|
||||||
|
},
|
||||||
|
"too many fields": {
|
||||||
|
line: "dev one two",
|
||||||
|
err: errors.New("dev line has not 2 fields as expected: dev one two"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
intf, err := extractInterfaceFromLine(testCase.line)
|
||||||
|
|
||||||
|
if testCase.err != nil {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.intf, intf)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func modifyCustomConfig(lines []string, settings configuration.OpenVPN,
|
func modifyCustomConfig(lines []string, settings configuration.OpenVPN,
|
||||||
connection models.Connection) (modified []string) {
|
connection models.Connection, intf string) (modified []string) {
|
||||||
// Remove some lines
|
// Remove some lines
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
switch {
|
switch {
|
||||||
@@ -22,6 +22,7 @@ func modifyCustomConfig(lines []string, settings configuration.OpenVPN,
|
|||||||
strings.HasPrefix(line, "user "),
|
strings.HasPrefix(line, "user "),
|
||||||
strings.HasPrefix(line, "proto "),
|
strings.HasPrefix(line, "proto "),
|
||||||
strings.HasPrefix(line, "remote "),
|
strings.HasPrefix(line, "remote "),
|
||||||
|
strings.HasPrefix(line, "dev "),
|
||||||
settings.Cipher != "" && strings.HasPrefix(line, "cipher "),
|
settings.Cipher != "" && strings.HasPrefix(line, "cipher "),
|
||||||
settings.Cipher != "" && strings.HasPrefix(line, "data-ciphers "),
|
settings.Cipher != "" && strings.HasPrefix(line, "data-ciphers "),
|
||||||
settings.Auth != "" && strings.HasPrefix(line, "auth "),
|
settings.Auth != "" && strings.HasPrefix(line, "auth "),
|
||||||
@@ -35,6 +36,7 @@ func modifyCustomConfig(lines []string, settings configuration.OpenVPN,
|
|||||||
// Add values
|
// Add values
|
||||||
modified = append(modified, connection.OpenVPNProtoLine())
|
modified = append(modified, connection.OpenVPNProtoLine())
|
||||||
modified = append(modified, connection.OpenVPNRemoteLine())
|
modified = append(modified, connection.OpenVPNRemoteLine())
|
||||||
|
modified = append(modified, "dev "+intf)
|
||||||
modified = append(modified, "mute-replay-warnings")
|
modified = append(modified, "mute-replay-warnings")
|
||||||
modified = append(modified, "auth-nocache")
|
modified = append(modified, "auth-nocache")
|
||||||
modified = append(modified, "pull-filter ignore \"auth-token\"") // prevent auth failed loop
|
modified = append(modified, "pull-filter ignore \"auth-token\"") // prevent auth failed loop
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ func Test_modifyCustomConfig(t *testing.T) {
|
|||||||
lines []string
|
lines []string
|
||||||
settings configuration.OpenVPN
|
settings configuration.OpenVPN
|
||||||
connection models.Connection
|
connection models.Connection
|
||||||
|
intf string
|
||||||
modified []string
|
modified []string
|
||||||
}{
|
}{
|
||||||
"mixed": {
|
"mixed": {
|
||||||
@@ -41,10 +42,12 @@ func Test_modifyCustomConfig(t *testing.T) {
|
|||||||
Port: 1194,
|
Port: 1194,
|
||||||
Protocol: constants.UDP,
|
Protocol: constants.UDP,
|
||||||
},
|
},
|
||||||
|
intf: "tun3",
|
||||||
modified: []string{
|
modified: []string{
|
||||||
"keep me here",
|
"keep me here",
|
||||||
"proto udp",
|
"proto udp",
|
||||||
"remote 1.2.3.4 1194",
|
"remote 1.2.3.4 1194",
|
||||||
|
"dev tun3",
|
||||||
"mute-replay-warnings",
|
"mute-replay-warnings",
|
||||||
"auth-nocache",
|
"auth-nocache",
|
||||||
"pull-filter ignore \"auth-token\"",
|
"pull-filter ignore \"auth-token\"",
|
||||||
@@ -69,7 +72,7 @@ func Test_modifyCustomConfig(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
modified := modifyCustomConfig(testCase.lines,
|
modified := modifyCustomConfig(testCase.lines,
|
||||||
testCase.settings, testCase.connection)
|
testCase.settings, testCase.connection, testCase.intf)
|
||||||
|
|
||||||
assert.Equal(t, testCase.modified, modified)
|
assert.Equal(t, testCase.modified, modified)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -242,10 +241,10 @@ func (r *routing) VPNDestinationIP() (ip net.IP, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type VPNLocalGatewayIPGetter interface {
|
type VPNLocalGatewayIPGetter interface {
|
||||||
VPNLocalGatewayIP() (ip net.IP, err error)
|
VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) {
|
func (r *routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) {
|
||||||
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err)
|
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err)
|
||||||
@@ -256,7 +255,7 @@ func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) {
|
|||||||
return nil, fmt.Errorf("%w: %s", ErrLinkByIndex, err)
|
return nil, fmt.Errorf("%w: %s", ErrLinkByIndex, err)
|
||||||
}
|
}
|
||||||
interfaceName := link.Attrs().Name
|
interfaceName := link.Attrs().Name
|
||||||
if interfaceName == string(constants.TUN) &&
|
if interfaceName == vpnIntf &&
|
||||||
route.Dst != nil &&
|
route.Dst != nil &&
|
||||||
route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) {
|
route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) {
|
||||||
return route.Gw, nil
|
return route.Gw, nil
|
||||||
|
|||||||
@@ -29,14 +29,16 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
|
|||||||
settings configuration.VPN, starter command.Starter, logger logging.Logger) (
|
settings configuration.VPN, starter command.Starter, logger logging.Logger) (
|
||||||
runner vpnRunner, serverName string, err error) {
|
runner vpnRunner, serverName string, err error) {
|
||||||
var connection models.Connection
|
var connection models.Connection
|
||||||
|
var netInterface string
|
||||||
var lines []string
|
var lines []string
|
||||||
if settings.OpenVPN.Config == "" {
|
if settings.OpenVPN.Config == "" {
|
||||||
|
netInterface = settings.OpenVPN.Interface
|
||||||
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection)
|
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
lines = providerConf.BuildConf(connection, settings.OpenVPN)
|
lines = providerConf.BuildConf(connection, settings.OpenVPN)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
lines, connection, err = custom.BuildConfig(settings.OpenVPN)
|
lines, connection, netInterface, err = custom.BuildConfig(settings.OpenVPN)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("%w: %s", errBuildConfig, err)
|
return nil, "", fmt.Errorf("%w: %s", errBuildConfig, err)
|
||||||
@@ -53,7 +55,7 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fw.SetVPNConnection(ctx, connection); err != nil {
|
if err := fw.SetVPNConnection(ctx, connection, netInterface); err != nil {
|
||||||
return nil, "", fmt.Errorf("%w: %s", errFirewall, err)
|
return nil, "", fmt.Errorf("%w: %s", errFirewall, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
|
||||||
"github.com/qdm12/gluetun/internal/portforward"
|
"github.com/qdm12/gluetun/internal/portforward"
|
||||||
"github.com/qdm12/gluetun/internal/provider"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -16,24 +14,23 @@ var (
|
|||||||
errStartPortForwarding = errors.New("cannot start port forwarding")
|
errStartPortForwarding = errors.New("cannot start port forwarding")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (l *Loop) startPortForwarding(ctx context.Context, enabled bool,
|
func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) {
|
||||||
portForwarder provider.PortForwarder, serverName string) (err error) {
|
if !data.portForwarding {
|
||||||
if !enabled {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// only used for PIA for now
|
// only used for PIA for now
|
||||||
gateway, err := l.routing.VPNLocalGatewayIP()
|
gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errObtainVPNLocalGateway, err)
|
return fmt.Errorf("%w: %s", errObtainVPNLocalGateway, err)
|
||||||
}
|
}
|
||||||
l.logger.Info("VPN gateway IP address: " + gateway.String())
|
l.logger.Info("VPN gateway IP address: " + gateway.String())
|
||||||
|
|
||||||
pfData := portforward.StartData{
|
pfData := portforward.StartData{
|
||||||
PortForwarder: portForwarder,
|
PortForwarder: data.portForwarder,
|
||||||
Gateway: gateway,
|
Gateway: gateway,
|
||||||
ServerName: serverName,
|
ServerName: data.serverName,
|
||||||
Interface: constants.TUN,
|
Interface: data.vpnIntf,
|
||||||
}
|
}
|
||||||
_, err = l.portForward.Start(ctx, pfData)
|
_, err = l.portForward.Start(ctx, pfData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
type tunnelUpData struct {
|
type tunnelUpData struct {
|
||||||
// Port forwarding
|
// Port forwarding
|
||||||
portForwarding bool
|
portForwarding bool
|
||||||
|
vpnIntf string
|
||||||
serverName string
|
serverName string
|
||||||
portForwarder provider.PortForwarder
|
portForwarder provider.PortForwarder
|
||||||
}
|
}
|
||||||
@@ -39,7 +40,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = l.startPortForwarding(ctx, data.portForwarding, data.portForwarder, data.serverName)
|
err = l.startPortForwarding(ctx, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.logger.Error(err.Error())
|
l.logger.Error(err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user