diff --git a/internal/openvpn/custom/custom.go b/internal/openvpn/custom/custom.go index 6514b6c8..d14229d2 100644 --- a/internal/openvpn/custom/custom.go +++ b/internal/openvpn/custom/custom.go @@ -3,237 +3,29 @@ package custom import ( "errors" "fmt" - "io" - "net" - "os" - "strconv" - "strings" "github.com/qdm12/gluetun/internal/configuration" - "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/gluetun/internal/provider/utils" ) var ( - errReadCustomConfig = errors.New("cannot read custom configuration file") - errExtractConnection = errors.New("cannot extract connection from custom configuration file") + ErrReadCustomConfig = errors.New("cannot read custom configuration file") + ErrExtractConnection = errors.New("cannot extract connection from custom configuration file") ) -func ProcessCustomConfig(settings configuration.OpenVPN) ( +func BuildConfig(settings configuration.OpenVPN) ( lines []string, connection models.OpenVPNConnection, err error) { lines, err = readCustomConfigLines(settings.Config) if err != nil { - return nil, connection, fmt.Errorf("%w: %s", errReadCustomConfig, err) + return nil, connection, fmt.Errorf("%w: %s", ErrReadCustomConfig, err) } connection, err = extractConnectionFromLines(lines) if err != nil { - return nil, connection, fmt.Errorf("%w: %s", errExtractConnection, err) + return nil, connection, fmt.Errorf("%w: %s", ErrExtractConnection, err) } lines = modifyCustomConfig(lines, settings, connection) return lines, connection, nil } - -func readCustomConfigLines(filepath string) ( - lines []string, err error) { - file, err := os.Open(filepath) - if err != nil { - return nil, err - } - - b, err := io.ReadAll(file) - if err != nil { - _ = file.Close() - return nil, err - } - - if err := file.Close(); err != nil { - return nil, err - } - - return strings.Split(string(b), "\n"), nil -} - -func modifyCustomConfig(lines []string, settings configuration.OpenVPN, - connection models.OpenVPNConnection) (modified []string) { - // Remove some lines - for _, line := range lines { - switch { - case strings.HasPrefix(line, "up "), - strings.HasPrefix(line, "down "), - strings.HasPrefix(line, "verb "), - strings.HasPrefix(line, "auth-user-pass "), - strings.HasPrefix(line, "user "), - strings.HasPrefix(line, "proto "), - strings.HasPrefix(line, "remote "), - settings.Cipher != "" && strings.HasPrefix(line, "cipher "), - settings.Cipher != "" && strings.HasPrefix(line, "data-ciphers "), - settings.Auth != "" && strings.HasPrefix(line, "auth "), - settings.MSSFix > 0 && strings.HasPrefix(line, "mssfix "), - !settings.IPv6 && strings.HasPrefix(line, "tun-ipv6"): - default: - modified = append(modified, line) - } - } - - // Add values - modified = append(modified, connection.ProtoLine()) - modified = append(modified, connection.RemoteLine()) - modified = append(modified, "mute-replay-warnings") - modified = append(modified, "auth-nocache") - modified = append(modified, "pull-filter ignore \"auth-token\"") // prevent auth failed loop - modified = append(modified, "auth-retry nointeract") - modified = append(modified, "suppress-timestamps") - if settings.User != "" { - modified = append(modified, "auth-user-pass "+constants.OpenVPNAuthConf) - } - modified = append(modified, "verb "+strconv.Itoa(settings.Verbosity)) - if settings.Cipher != "" { - modified = append(modified, utils.CipherLines(settings.Cipher, settings.Version)...) - } - if settings.Auth != "" { - modified = append(modified, "auth "+settings.Auth) - } - if settings.MSSFix > 0 { - modified = append(modified, "mssfix "+strconv.Itoa(int(settings.MSSFix))) - } - if !settings.IPv6 { - modified = append(modified, `pull-filter ignore "route-ipv6"`) - modified = append(modified, `pull-filter ignore "ifconfig-ipv6"`) - } - if !settings.Root { - modified = append(modified, "user "+settings.ProcUser) - } - - return modified -} - -var ( - errRemoteLineNotFound = errors.New("remote line not found") -) - -// extractConnectionFromLines always takes the first remote line only. -func extractConnectionFromLines(lines []string) ( - connection models.OpenVPNConnection, err error) { - for i, line := range lines { - newConnectionData, err := extractConnectionFromLine(line) - if err != nil { - return connection, fmt.Errorf("on line %d: %w", i+1, err) - } - connection.UpdateEmptyWith(newConnectionData) - - if connection.Protocol != "" && connection.IP != nil { - break - } - } - - if connection.IP == nil { - return connection, errRemoteLineNotFound - } - - if connection.Protocol == "" { - connection.Protocol = constants.UDP - } - - if connection.Port == 0 { - connection.Port = 1194 - if connection.Protocol == constants.TCP { - connection.Port = 443 - } - } - - return connection, nil -} - -var ( - errExtractProto = errors.New("failed extracting protocol from proto line") - errExtractRemote = errors.New("failed extracting protocol from remote line") -) - -func extractConnectionFromLine(line string) ( - connection models.OpenVPNConnection, err error) { - switch { - case strings.HasPrefix(line, "proto "): - connection.Protocol, err = extractProto(line) - if err != nil { - return connection, fmt.Errorf("%w: %s", errExtractProto, err) - } - - // only take the first remote line - case strings.HasPrefix(line, "remote ") && connection.IP == nil: - connection.IP, connection.Port, connection.Protocol, err = extractRemote(line) - if err != nil { - return connection, fmt.Errorf("%w: %s", errExtractRemote, err) - } - } - - return connection, nil -} - -var ( - errProtoLineFieldsCount = errors.New("proto line has not 2 fields as expected") - errProtocolNotSupported = errors.New("network protocol not supported") -) - -func extractProto(line string) (protocol string, err error) { - fields := strings.Fields(line) - if len(fields) != 2 { //nolint:gomnd - return "", fmt.Errorf("%w: %s", errProtoLineFieldsCount, line) - } - - switch fields[1] { - case "tcp", "udp": - default: - return "", fmt.Errorf("%w: %s", errProtocolNotSupported, fields[1]) - } - - return fields[1], nil -} - -var ( - errRemoteLineFieldsCount = errors.New("remote line has not 2 fields as expected") - errHostNotIP = errors.New("host is not an an IP address") - errPortNotValid = errors.New("port is not valid") -) - -func extractRemote(line string) (ip net.IP, port uint16, - protocol string, err error) { - fields := strings.Fields(line) - n := len(fields) - - if n < 2 || n > 4 { - return nil, 0, "", fmt.Errorf("%w: %s", errRemoteLineFieldsCount, line) - } - - host := fields[1] - ip = net.ParseIP(host) - if ip == nil { - return nil, 0, "", fmt.Errorf("%w: %s", errHostNotIP, host) - // TODO resolve hostname once there is an option to allow it through - // the firewall before the VPN is up. - } - - if n > 2 { //nolint:gomnd - portInt, err := strconv.Atoi(fields[2]) - if err != nil { - return nil, 0, "", fmt.Errorf("%w: %s", errPortNotValid, line) - } else if portInt < 1 || portInt > 65535 { - return nil, 0, "", fmt.Errorf("%w: not between 1 and 65535: %d", errPortNotValid, portInt) - } - port = uint16(portInt) - } - - if n > 3 { //nolint:gomnd - switch fields[3] { - case "tcp", "udp": - protocol = fields[3] - default: - return nil, 0, "", fmt.Errorf("%w: %s", errProtocolNotSupported, fields[3]) - } - } - - return ip, port, protocol, nil -} diff --git a/internal/openvpn/custom/custom_test.go b/internal/openvpn/custom/custom_test.go index 3c60365a..084dd9d5 100644 --- a/internal/openvpn/custom/custom_test.go +++ b/internal/openvpn/custom/custom_test.go @@ -1,7 +1,6 @@ package custom import ( - "errors" "net" "os" "testing" @@ -13,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func Test_ProcessCustomConfig(t *testing.T) { +func Test_BuildConfig(t *testing.T) { t.Parallel() file, err := os.CreateTemp("", "") @@ -33,7 +32,7 @@ func Test_ProcessCustomConfig(t *testing.T) { Config: file.Name(), } - lines, connection, err := ProcessCustomConfig(settings) + lines, connection, err := BuildConfig(settings) assert.NoError(t, err) expectedLines := []string{ @@ -62,348 +61,3 @@ func Test_ProcessCustomConfig(t *testing.T) { } assert.Equal(t, expectedConnection, connection) } - -func Test_readCustomConfigLines(t *testing.T) { - t.Parallel() - - file, err := os.CreateTemp("", "") - require.NoError(t, err) - defer removeFile(t, file.Name()) - defer file.Close() - - _, err = file.WriteString("line one\nline two\nline three\n") - require.NoError(t, err) - - err = file.Close() - require.NoError(t, err) - - lines, err := readCustomConfigLines(file.Name()) - assert.NoError(t, err) - - expectedLines := []string{ - "line one", "line two", "line three", "", - } - assert.Equal(t, expectedLines, lines) -} - -func removeFile(t *testing.T, filename string) { - t.Helper() - err := os.RemoveAll(filename) - require.NoError(t, err) -} - -func Test_modifyCustomConfig(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - lines []string - settings configuration.OpenVPN - connection models.OpenVPNConnection - modified []string - }{ - "mixed": { - lines: []string{ - "up bla", - "proto tcp", - "remote 5.5.5.5", - "cipher bla", - "tun-ipv6", - "keep me here", - "auth bla", - }, - settings: configuration.OpenVPN{ - User: "user", - Cipher: "cipher", - Auth: "auth", - MSSFix: 1000, - ProcUser: "procuser", - }, - connection: models.OpenVPNConnection{ - IP: net.IPv4(1, 2, 3, 4), - Port: 1194, - Protocol: constants.UDP, - }, - modified: []string{ - "keep me here", - "proto udp", - "remote 1.2.3.4 1194", - "mute-replay-warnings", - "auth-nocache", - "pull-filter ignore \"auth-token\"", - "auth-retry nointeract", - "suppress-timestamps", - "auth-user-pass /etc/openvpn/auth.conf", - "verb 0", - "data-ciphers-fallback cipher", - "data-ciphers cipher", - "auth auth", - "mssfix 1000", - "pull-filter ignore \"route-ipv6\"", - "pull-filter ignore \"ifconfig-ipv6\"", - "user procuser", - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - modified := modifyCustomConfig(testCase.lines, - testCase.settings, testCase.connection) - - assert.Equal(t, testCase.modified, modified) - }) - } -} - -func Test_extractConnectionFromLines(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - lines []string - connection models.OpenVPNConnection - err error - }{ - "success": { - lines: []string{"bla bla", "proto tcp", "remote 1.2.3.4 1194 tcp"}, - connection: models.OpenVPNConnection{ - IP: net.IPv4(1, 2, 3, 4), - Port: 1194, - Protocol: constants.TCP, - }, - }, - "extraction error": { - lines: []string{"bla bla", "proto bad", "remote 1.2.3.4 1194 tcp"}, - err: errors.New("on line 2: failed extracting protocol from proto line: network protocol not supported: bad"), - }, - "only use first values found": { - lines: []string{"proto udp", "proto tcp", "remote 1.2.3.4 443 tcp", "remote 5.2.3.4 1194 udp"}, - connection: models.OpenVPNConnection{ - IP: net.IPv4(1, 2, 3, 4), - Port: 443, - Protocol: constants.UDP, - }, - }, - "no IP found": { - lines: []string{"proto tcp"}, - connection: models.OpenVPNConnection{ - Protocol: constants.TCP, - }, - err: errRemoteLineNotFound, - }, - "default TCP port": { - lines: []string{"remote 1.2.3.4", "proto tcp"}, - connection: models.OpenVPNConnection{ - IP: net.IPv4(1, 2, 3, 4), - Port: 443, - Protocol: constants.TCP, - }, - }, - "default UDP port": { - lines: []string{"remote 1.2.3.4", "proto udp"}, - connection: models.OpenVPNConnection{ - IP: net.IPv4(1, 2, 3, 4), - Port: 1194, - Protocol: constants.UDP, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - connection, err := extractConnectionFromLines(testCase.lines) - - if testCase.err != nil { - require.Error(t, err) - assert.Equal(t, testCase.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - - assert.Equal(t, testCase.connection, connection) - }) - } -} - -func Test_extractConnectionFromLine(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - line string - connection models.OpenVPNConnection - isErr error - }{ - "irrelevant line": { - line: "bla bla", - }, - "extract proto error": { - line: "proto bad", - isErr: errExtractProto, - }, - "extract proto success": { - line: "proto tcp", - connection: models.OpenVPNConnection{ - Protocol: constants.TCP, - }, - }, - "extract remote error": { - line: "remote bad", - isErr: errExtractRemote, - }, - "extract remote success": { - line: "remote 1.2.3.4 1194 udp", - connection: models.OpenVPNConnection{ - IP: net.IPv4(1, 2, 3, 4), - Port: 1194, - Protocol: constants.UDP, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - connection, err := extractConnectionFromLine(testCase.line) - - if testCase.isErr != nil { - assert.ErrorIs(t, err, testCase.isErr) - } else { - assert.NoError(t, err) - } - - assert.Equal(t, testCase.connection, connection) - }) - } -} - -func Test_extractProto(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - line string - protocol string - err error - }{ - "fields error": { - line: "proto one two", - err: errors.New("proto line has not 2 fields as expected: proto one two"), - }, - "bad protocol": { - line: "proto bad", - err: errors.New("network protocol not supported: bad"), - }, - "udp": { - line: "proto udp", - protocol: constants.UDP, - }, - "tcp": { - line: "proto tcp", - protocol: constants.TCP, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - protocol, err := extractProto(testCase.line) - - if testCase.err != nil { - require.Error(t, err) - assert.Equal(t, testCase.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - - assert.Equal(t, testCase.protocol, protocol) - }) - } -} - -func Test_extractRemote(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - line string - ip net.IP - port uint16 - protocol string - err error - }{ - "not enough fields": { - line: "remote", - err: errors.New("remote line has not 2 fields as expected: remote"), - }, - "too many fields": { - line: "remote one two three four", - err: errors.New("remote line has not 2 fields as expected: remote one two three four"), - }, - "host is not an IP": { - line: "remote somehost.com", - err: errors.New("host is not an an IP address: somehost.com"), - }, - "only IP host": { - line: "remote 1.2.3.4", - ip: net.IPv4(1, 2, 3, 4), - }, - "port not an integer": { - line: "remote 1.2.3.4 bad", - err: errors.New("port is not valid: remote 1.2.3.4 bad"), - }, - "port is zero": { - line: "remote 1.2.3.4 0", - err: errors.New("port is not valid: not between 1 and 65535: 0"), - }, - "port is minus one": { - line: "remote 1.2.3.4 -1", - err: errors.New("port is not valid: not between 1 and 65535: -1"), - }, - "port is over 65535": { - line: "remote 1.2.3.4 65536", - err: errors.New("port is not valid: not between 1 and 65535: 65536"), - }, - "IP host and port": { - line: "remote 1.2.3.4 8000", - ip: net.IPv4(1, 2, 3, 4), - port: 8000, - }, - "invalid protocol": { - line: "remote 1.2.3.4 8000 bad", - err: errors.New("network protocol not supported: bad"), - }, - "IP host and port and protocol": { - line: "remote 1.2.3.4 8000 udp", - ip: net.IPv4(1, 2, 3, 4), - port: 8000, - protocol: constants.UDP, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - ip, port, protocol, err := extractRemote(testCase.line) - - if testCase.err != nil { - require.Error(t, err) - assert.Equal(t, testCase.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - - assert.Equal(t, testCase.ip, ip) - assert.Equal(t, testCase.port, port) - assert.Equal(t, testCase.protocol, protocol) - }) - } -} diff --git a/internal/openvpn/custom/extract.go b/internal/openvpn/custom/extract.go new file mode 100644 index 00000000..80defbf0 --- /dev/null +++ b/internal/openvpn/custom/extract.go @@ -0,0 +1,139 @@ +package custom + +import ( + "errors" + "fmt" + "net" + "strconv" + "strings" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" +) + +var ( + errRemoteLineNotFound = errors.New("remote line not found") +) + +// extractConnectionFromLines always takes the first remote line only. +func extractConnectionFromLines(lines []string) ( + connection models.OpenVPNConnection, err error) { + for i, line := range lines { + newConnectionData, err := extractConnectionFromLine(line) + if err != nil { + return connection, fmt.Errorf("on line %d: %w", i+1, err) + } + connection.UpdateEmptyWith(newConnectionData) + + if connection.Protocol != "" && connection.IP != nil { + break + } + } + + if connection.IP == nil { + return connection, errRemoteLineNotFound + } + + if connection.Protocol == "" { + connection.Protocol = constants.UDP + } + + if connection.Port == 0 { + connection.Port = 1194 + if connection.Protocol == constants.TCP { + connection.Port = 443 + } + } + + return connection, nil +} + +var ( + errExtractProto = errors.New("failed extracting protocol from proto line") + errExtractRemote = errors.New("failed extracting protocol from remote line") +) + +func extractConnectionFromLine(line string) ( + connection models.OpenVPNConnection, err error) { + switch { + case strings.HasPrefix(line, "proto "): + connection.Protocol, err = extractProto(line) + if err != nil { + return connection, fmt.Errorf("%w: %s", errExtractProto, err) + } + + // only take the first remote line + case strings.HasPrefix(line, "remote ") && connection.IP == nil: + connection.IP, connection.Port, connection.Protocol, err = extractRemote(line) + if err != nil { + return connection, fmt.Errorf("%w: %s", errExtractRemote, err) + } + } + + return connection, nil +} + +var ( + errProtoLineFieldsCount = errors.New("proto line has not 2 fields as expected") + errProtocolNotSupported = errors.New("network protocol not supported") +) + +func extractProto(line string) (protocol string, err error) { + fields := strings.Fields(line) + if len(fields) != 2 { //nolint:gomnd + return "", fmt.Errorf("%w: %s", errProtoLineFieldsCount, line) + } + + switch fields[1] { + case "tcp", "udp": + default: + return "", fmt.Errorf("%w: %s", errProtocolNotSupported, fields[1]) + } + + return fields[1], nil +} + +var ( + errRemoteLineFieldsCount = errors.New("remote line has not 2 fields as expected") + errHostNotIP = errors.New("host is not an an IP address") + errPortNotValid = errors.New("port is not valid") +) + +func extractRemote(line string) (ip net.IP, port uint16, + protocol string, err error) { + fields := strings.Fields(line) + n := len(fields) + + if n < 2 || n > 4 { + return nil, 0, "", fmt.Errorf("%w: %s", errRemoteLineFieldsCount, line) + } + + host := fields[1] + ip = net.ParseIP(host) + if ip == nil { + return nil, 0, "", fmt.Errorf("%w: %s", errHostNotIP, host) + // TODO resolve hostname once there is an option to allow it through + // the firewall before the VPN is up. + } + + if n > 2 { //nolint:gomnd + portInt, err := strconv.Atoi(fields[2]) + if err != nil { + return nil, 0, "", fmt.Errorf("%w: %s", errPortNotValid, line) + } else if portInt < 1 || portInt > 65535 { + return nil, 0, "", fmt.Errorf("%w: not between 1 and 65535: %d", errPortNotValid, portInt) + } + port = uint16(portInt) + } + + if n > 3 { //nolint:gomnd + switch fields[3] { + case "tcp", "udp": + protocol = fields[3] + default: + return nil, 0, "", fmt.Errorf("%w: %s", errProtocolNotSupported, fields[3]) + } + } + + return ip, port, protocol, nil +} diff --git a/internal/openvpn/custom/extract_test.go b/internal/openvpn/custom/extract_test.go new file mode 100644 index 00000000..e0f7fc57 --- /dev/null +++ b/internal/openvpn/custom/extract_test.go @@ -0,0 +1,262 @@ +package custom + +import ( + "errors" + "net" + "testing" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_extractConnectionFromLines(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + lines []string + connection models.OpenVPNConnection + err error + }{ + "success": { + lines: []string{"bla bla", "proto tcp", "remote 1.2.3.4 1194 tcp"}, + connection: models.OpenVPNConnection{ + IP: net.IPv4(1, 2, 3, 4), + Port: 1194, + Protocol: constants.TCP, + }, + }, + "extraction error": { + lines: []string{"bla bla", "proto bad", "remote 1.2.3.4 1194 tcp"}, + err: errors.New("on line 2: failed extracting protocol from proto line: network protocol not supported: bad"), + }, + "only use first values found": { + lines: []string{"proto udp", "proto tcp", "remote 1.2.3.4 443 tcp", "remote 5.2.3.4 1194 udp"}, + connection: models.OpenVPNConnection{ + IP: net.IPv4(1, 2, 3, 4), + Port: 443, + Protocol: constants.UDP, + }, + }, + "no IP found": { + lines: []string{"proto tcp"}, + connection: models.OpenVPNConnection{ + Protocol: constants.TCP, + }, + err: errRemoteLineNotFound, + }, + "default TCP port": { + lines: []string{"remote 1.2.3.4", "proto tcp"}, + connection: models.OpenVPNConnection{ + IP: net.IPv4(1, 2, 3, 4), + Port: 443, + Protocol: constants.TCP, + }, + }, + "default UDP port": { + lines: []string{"remote 1.2.3.4", "proto udp"}, + connection: models.OpenVPNConnection{ + IP: net.IPv4(1, 2, 3, 4), + Port: 1194, + Protocol: constants.UDP, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + connection, err := extractConnectionFromLines(testCase.lines) + + if testCase.err != nil { + require.Error(t, err) + assert.Equal(t, testCase.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, testCase.connection, connection) + }) + } +} + +func Test_extractConnectionFromLine(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + line string + connection models.OpenVPNConnection + isErr error + }{ + "irrelevant line": { + line: "bla bla", + }, + "extract proto error": { + line: "proto bad", + isErr: errExtractProto, + }, + "extract proto success": { + line: "proto tcp", + connection: models.OpenVPNConnection{ + Protocol: constants.TCP, + }, + }, + "extract remote error": { + line: "remote bad", + isErr: errExtractRemote, + }, + "extract remote success": { + line: "remote 1.2.3.4 1194 udp", + connection: models.OpenVPNConnection{ + IP: net.IPv4(1, 2, 3, 4), + Port: 1194, + Protocol: constants.UDP, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + connection, err := extractConnectionFromLine(testCase.line) + + if testCase.isErr != nil { + assert.ErrorIs(t, err, testCase.isErr) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, testCase.connection, connection) + }) + } +} + +func Test_extractProto(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + line string + protocol string + err error + }{ + "fields error": { + line: "proto one two", + err: errors.New("proto line has not 2 fields as expected: proto one two"), + }, + "bad protocol": { + line: "proto bad", + err: errors.New("network protocol not supported: bad"), + }, + "udp": { + line: "proto udp", + protocol: constants.UDP, + }, + "tcp": { + line: "proto tcp", + protocol: constants.TCP, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + protocol, err := extractProto(testCase.line) + + if testCase.err != nil { + require.Error(t, err) + assert.Equal(t, testCase.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, testCase.protocol, protocol) + }) + } +} + +func Test_extractRemote(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + line string + ip net.IP + port uint16 + protocol string + err error + }{ + "not enough fields": { + line: "remote", + err: errors.New("remote line has not 2 fields as expected: remote"), + }, + "too many fields": { + line: "remote one two three four", + err: errors.New("remote line has not 2 fields as expected: remote one two three four"), + }, + "host is not an IP": { + line: "remote somehost.com", + err: errors.New("host is not an an IP address: somehost.com"), + }, + "only IP host": { + line: "remote 1.2.3.4", + ip: net.IPv4(1, 2, 3, 4), + }, + "port not an integer": { + line: "remote 1.2.3.4 bad", + err: errors.New("port is not valid: remote 1.2.3.4 bad"), + }, + "port is zero": { + line: "remote 1.2.3.4 0", + err: errors.New("port is not valid: not between 1 and 65535: 0"), + }, + "port is minus one": { + line: "remote 1.2.3.4 -1", + err: errors.New("port is not valid: not between 1 and 65535: -1"), + }, + "port is over 65535": { + line: "remote 1.2.3.4 65536", + err: errors.New("port is not valid: not between 1 and 65535: 65536"), + }, + "IP host and port": { + line: "remote 1.2.3.4 8000", + ip: net.IPv4(1, 2, 3, 4), + port: 8000, + }, + "invalid protocol": { + line: "remote 1.2.3.4 8000 bad", + err: errors.New("network protocol not supported: bad"), + }, + "IP host and port and protocol": { + line: "remote 1.2.3.4 8000 udp", + ip: net.IPv4(1, 2, 3, 4), + port: 8000, + protocol: constants.UDP, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + ip, port, protocol, err := extractRemote(testCase.line) + + if testCase.err != nil { + require.Error(t, err) + assert.Equal(t, testCase.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, testCase.ip, ip) + assert.Equal(t, testCase.port, port) + assert.Equal(t, testCase.protocol, protocol) + }) + } +} diff --git a/internal/openvpn/custom/helpers_test.go b/internal/openvpn/custom/helpers_test.go new file mode 100644 index 00000000..b6f6df2e --- /dev/null +++ b/internal/openvpn/custom/helpers_test.go @@ -0,0 +1,14 @@ +package custom + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func removeFile(t *testing.T, filename string) { + t.Helper() + err := os.RemoveAll(filename) + require.NoError(t, err) +} diff --git a/internal/openvpn/custom/modify.go b/internal/openvpn/custom/modify.go new file mode 100644 index 00000000..66da7a93 --- /dev/null +++ b/internal/openvpn/custom/modify.go @@ -0,0 +1,65 @@ +package custom + +import ( + "strconv" + "strings" + + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/utils" +) + +func modifyCustomConfig(lines []string, settings configuration.OpenVPN, + connection models.OpenVPNConnection) (modified []string) { + // Remove some lines + for _, line := range lines { + switch { + case strings.HasPrefix(line, "up "), + strings.HasPrefix(line, "down "), + strings.HasPrefix(line, "verb "), + strings.HasPrefix(line, "auth-user-pass "), + strings.HasPrefix(line, "user "), + strings.HasPrefix(line, "proto "), + strings.HasPrefix(line, "remote "), + settings.Cipher != "" && strings.HasPrefix(line, "cipher "), + settings.Cipher != "" && strings.HasPrefix(line, "data-ciphers "), + settings.Auth != "" && strings.HasPrefix(line, "auth "), + settings.MSSFix > 0 && strings.HasPrefix(line, "mssfix "), + !settings.IPv6 && strings.HasPrefix(line, "tun-ipv6"): + default: + modified = append(modified, line) + } + } + + // Add values + modified = append(modified, connection.ProtoLine()) + modified = append(modified, connection.RemoteLine()) + modified = append(modified, "mute-replay-warnings") + modified = append(modified, "auth-nocache") + modified = append(modified, "pull-filter ignore \"auth-token\"") // prevent auth failed loop + modified = append(modified, "auth-retry nointeract") + modified = append(modified, "suppress-timestamps") + if settings.User != "" { + modified = append(modified, "auth-user-pass "+constants.OpenVPNAuthConf) + } + modified = append(modified, "verb "+strconv.Itoa(settings.Verbosity)) + if settings.Cipher != "" { + modified = append(modified, utils.CipherLines(settings.Cipher, settings.Version)...) + } + if settings.Auth != "" { + modified = append(modified, "auth "+settings.Auth) + } + if settings.MSSFix > 0 { + modified = append(modified, "mssfix "+strconv.Itoa(int(settings.MSSFix))) + } + if !settings.IPv6 { + modified = append(modified, `pull-filter ignore "route-ipv6"`) + modified = append(modified, `pull-filter ignore "ifconfig-ipv6"`) + } + if !settings.Root { + modified = append(modified, "user "+settings.ProcUser) + } + + return modified +} diff --git a/internal/openvpn/custom/modify_test.go b/internal/openvpn/custom/modify_test.go new file mode 100644 index 00000000..e21aaca8 --- /dev/null +++ b/internal/openvpn/custom/modify_test.go @@ -0,0 +1,77 @@ +package custom + +import ( + "net" + "testing" + + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" + "github.com/stretchr/testify/assert" +) + +func Test_modifyCustomConfig(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + lines []string + settings configuration.OpenVPN + connection models.OpenVPNConnection + modified []string + }{ + "mixed": { + lines: []string{ + "up bla", + "proto tcp", + "remote 5.5.5.5", + "cipher bla", + "tun-ipv6", + "keep me here", + "auth bla", + }, + settings: configuration.OpenVPN{ + User: "user", + Cipher: "cipher", + Auth: "auth", + MSSFix: 1000, + ProcUser: "procuser", + }, + connection: models.OpenVPNConnection{ + IP: net.IPv4(1, 2, 3, 4), + Port: 1194, + Protocol: constants.UDP, + }, + modified: []string{ + "keep me here", + "proto udp", + "remote 1.2.3.4 1194", + "mute-replay-warnings", + "auth-nocache", + "pull-filter ignore \"auth-token\"", + "auth-retry nointeract", + "suppress-timestamps", + "auth-user-pass /etc/openvpn/auth.conf", + "verb 0", + "data-ciphers-fallback cipher", + "data-ciphers cipher", + "auth auth", + "mssfix 1000", + "pull-filter ignore \"route-ipv6\"", + "pull-filter ignore \"ifconfig-ipv6\"", + "user procuser", + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + modified := modifyCustomConfig(testCase.lines, + testCase.settings, testCase.connection) + + assert.Equal(t, testCase.modified, modified) + }) + } +} diff --git a/internal/openvpn/custom/read.go b/internal/openvpn/custom/read.go new file mode 100644 index 00000000..8e2ee03d --- /dev/null +++ b/internal/openvpn/custom/read.go @@ -0,0 +1,27 @@ +package custom + +import ( + "io" + "os" + "strings" +) + +func readCustomConfigLines(filepath string) ( + lines []string, err error) { + file, err := os.Open(filepath) + if err != nil { + return nil, err + } + + b, err := io.ReadAll(file) + if err != nil { + _ = file.Close() + return nil, err + } + + if err := file.Close(); err != nil { + return nil, err + } + + return strings.Split(string(b), "\n"), nil +} diff --git a/internal/openvpn/custom/read_test.go b/internal/openvpn/custom/read_test.go new file mode 100644 index 00000000..bc050427 --- /dev/null +++ b/internal/openvpn/custom/read_test.go @@ -0,0 +1,32 @@ +package custom + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_readCustomConfigLines(t *testing.T) { + t.Parallel() + + file, err := os.CreateTemp("", "") + require.NoError(t, err) + defer removeFile(t, file.Name()) + defer file.Close() + + _, err = file.WriteString("line one\nline two\nline three\n") + require.NoError(t, err) + + err = file.Close() + require.NoError(t, err) + + lines, err := readCustomConfigLines(file.Name()) + assert.NoError(t, err) + + expectedLines := []string{ + "line one", "line two", "line three", "", + } + assert.Equal(t, expectedLines, lines) +} diff --git a/internal/openvpn/run.go b/internal/openvpn/run.go index c749ca58..a028c9c9 100644 --- a/internal/openvpn/run.go +++ b/internal/openvpn/run.go @@ -37,7 +37,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { lines = providerConf.BuildConf(connection, openVPNSettings) } } else { - lines, connection, err = custom.ProcessCustomConfig(openVPNSettings) + lines, connection, err = custom.BuildConfig(openVPNSettings) } if err != nil { l.crashed(ctx, err)