Compare commits

..

1 Commits

Author SHA1 Message Date
Quentin McGaw
0717578b06 change!(server): auth is now required for all routes 2025-11-14 21:30:42 +00:00
50 changed files with 4031 additions and 7540 deletions

View File

@@ -93,9 +93,6 @@ jobs:
- name: Run Gluetun container with Mullvad configuration - name: Run Gluetun container with Mullvad configuration
run: echo -e "${{ secrets.MULLVAD_WIREGUARD_PRIVATE_KEY }}\n${{ secrets.MULLVAD_WIREGUARD_ADDRESS }}" | ./ci/runner mullvad run: echo -e "${{ secrets.MULLVAD_WIREGUARD_PRIVATE_KEY }}\n${{ secrets.MULLVAD_WIREGUARD_ADDRESS }}" | ./ci/runner mullvad
- name: Run Gluetun container with ProtonVPN configuration
run: echo -e "${{ secrets.PROTONVPN_WIREGUARD_PRIVATE_KEY }}" | ./ci/runner protonvpn
codeql: codeql:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
@@ -121,7 +118,7 @@ jobs:
github.event_name == 'release' || github.event_name == 'release' ||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository && github.actor != 'dependabot[bot]') (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository && github.actor != 'dependabot[bot]')
) )
needs: [verify, verify-private, codeql] needs: [verify, codeql]
permissions: permissions:
actions: read actions: read
contents: read contents: read

View File

@@ -20,7 +20,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v5
- uses: DavidAnson/markdownlint-cli2-action@v21 - uses: DavidAnson/markdownlint-cli2-action@v20
with: with:
globs: "**.md" globs: "**.md"
config: .markdownlint-cli2.jsonc config: .markdownlint-cli2.jsonc

View File

@@ -163,9 +163,8 @@ ENV VPN_SERVICE_PROVIDER=pia \
LOG_LEVEL=info \ LOG_LEVEL=info \
# Health # Health
HEALTH_SERVER_ADDRESS=127.0.0.1:9999 \ HEALTH_SERVER_ADDRESS=127.0.0.1:9999 \
HEALTH_TARGET_ADDRESSES=cloudflare.com:443,github.com:443 \ HEALTH_TARGET_ADDRESS=cloudflare.com:443 \
HEALTH_ICMP_TARGET_IPS=1.1.1.1,8.8.8.8 \ HEALTH_ICMP_TARGET_IP=1.1.1.1 \
HEALTH_SMALL_CHECK_TYPE=icmp \
HEALTH_RESTART_VPN=on \ HEALTH_RESTART_VPN=on \
# DNS # DNS
DNS_SERVER=on \ DNS_SERVER=on \
@@ -208,7 +207,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
UPDATER_PERIOD=0 \ UPDATER_PERIOD=0 \
UPDATER_MIN_RATIO=0.8 \ UPDATER_MIN_RATIO=0.8 \
UPDATER_VPN_SERVICE_PROVIDERS= \ UPDATER_VPN_SERVICE_PROVIDERS= \
UPDATER_PROTONVPN_EMAIL= \ UPDATER_PROTONVPN_USERNAME= \
UPDATER_PROTONVPN_PASSWORD= \ UPDATER_PROTONVPN_PASSWORD= \
# Public IP # Public IP
PUBLICIP_FILE="/tmp/gluetun/ip" \ PUBLICIP_FILE="/tmp/gluetun/ip" \

View File

@@ -1,7 +1,5 @@
# Gluetun VPN client # Gluetun VPN client
⚠️ This and [gluetun-wiki](https://github.com/qdm12/gluetun-wiki) are the only websites for Gluetun, other websites claiming to be official are scams ⚠️
Lightweight swiss-army-knife-like VPN client to multiple VPN service providers Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
![Title image](https://raw.githubusercontent.com/qdm12/gluetun/master/title.svg) ![Title image](https://raw.githubusercontent.com/qdm12/gluetun/master/title.svg)

View File

@@ -21,8 +21,6 @@ func main() {
switch os.Args[1] { switch os.Args[1] {
case "mullvad": case "mullvad":
err = internal.MullvadTest(ctx) err = internal.MullvadTest(ctx)
case "protonvpn":
err = internal.ProtonVPNTest(ctx)
default: default:
err = fmt.Errorf("unknown command: %s", os.Args[1]) err = fmt.Errorf("unknown command: %s", os.Args[1])
} }

View File

@@ -1,27 +1,193 @@
package internal package internal
import ( import (
"bufio"
"context" "context"
"fmt" "fmt"
"io"
"os"
"regexp"
"strings"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
) )
func MullvadTest(ctx context.Context) error { func MullvadTest(ctx context.Context) error {
expectedSecrets := []string{ secrets, err := readSecrets(ctx)
"Wireguard private key",
"Wireguard address",
}
secrets, err := readSecrets(ctx, expectedSecrets)
if err != nil { if err != nil {
return fmt.Errorf("reading secrets: %w", err) return fmt.Errorf("reading secrets: %w", err)
} }
env := []string{ const timeout = 15 * time.Second
"VPN_SERVICE_PROVIDER=mullvad", ctx, cancel := context.WithTimeout(ctx, timeout)
"VPN_TYPE=wireguard", defer cancel()
"LOG_LEVEL=debug",
"SERVER_COUNTRIES=USA", client, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
"WIREGUARD_PRIVATE_KEY=" + secrets[0], if err != nil {
"WIREGUARD_ADDRESSES=" + secrets[1], return fmt.Errorf("creating Docker client: %w", err)
}
defer client.Close()
config := &container.Config{
Image: "qmcgaw/gluetun",
StopTimeout: ptrTo(3),
Env: []string{
"VPN_SERVICE_PROVIDER=mullvad",
"VPN_TYPE=wireguard",
"LOG_LEVEL=debug",
"SERVER_COUNTRIES=USA",
"WIREGUARD_PRIVATE_KEY=" + secrets.mullvadWireguardPrivateKey,
"WIREGUARD_ADDRESSES=" + secrets.mullvadWireguardAddress,
},
}
hostConfig := &container.HostConfig{
AutoRemove: true,
CapAdd: []string{"NET_ADMIN", "NET_RAW"},
}
networkConfig := (*network.NetworkingConfig)(nil)
platform := (*v1.Platform)(nil)
const containerName = "" // auto-generated name
response, err := client.ContainerCreate(ctx, config, hostConfig, networkConfig, platform, containerName)
if err != nil {
return fmt.Errorf("creating container: %w", err)
}
for _, warning := range response.Warnings {
fmt.Println("Warning during container creation:", warning)
}
containerID := response.ID
defer stopContainer(client, containerID)
beforeStartTime := time.Now()
err = client.ContainerStart(ctx, containerID, container.StartOptions{})
if err != nil {
return fmt.Errorf("starting container: %w", err)
}
return waitForLogLine(ctx, client, containerID, beforeStartTime)
}
func ptrTo[T any](v T) *T { return &v }
type secrets struct {
mullvadWireguardPrivateKey string
mullvadWireguardAddress string
}
func readSecrets(ctx context.Context) (secrets, error) {
expectedSecrets := [...]string{
"Mullvad Wireguard private key",
"Mullvad Wireguard address",
}
scanner := bufio.NewScanner(os.Stdin)
lines := make([]string, 0, len(expectedSecrets))
for i := range expectedSecrets {
fmt.Println("🤫 reading", expectedSecrets[i], "from Stdin...")
if !scanner.Scan() {
break
}
lines = append(lines, strings.TrimSpace(scanner.Text()))
if ctx.Err() != nil {
return secrets{}, ctx.Err()
}
}
if err := scanner.Err(); err != nil {
return secrets{}, fmt.Errorf("reading secrets from stdin: %w", err)
}
if len(lines) < len(expectedSecrets) {
return secrets{}, fmt.Errorf("expected %d secrets via Stdin, but only received %d",
len(expectedSecrets), len(lines))
}
for i, line := range lines {
if line == "" {
return secrets{}, fmt.Errorf("secret on line %d/%d was empty", i+1, len(lines))
}
}
return secrets{
mullvadWireguardPrivateKey: lines[0],
mullvadWireguardAddress: lines[1],
}, nil
}
func stopContainer(client *client.Client, containerID string) {
const stopTimeout = 5 * time.Second // must be higher than 3s, see above [container.Config]'s StopTimeout field
stopCtx, stopCancel := context.WithTimeout(context.Background(), stopTimeout)
defer stopCancel()
err := client.ContainerStop(stopCtx, containerID, container.StopOptions{})
if err != nil {
fmt.Println("failed to stop container:", err)
}
}
var successRegexp = regexp.MustCompile(`^.+Public IP address is .+$`)
func waitForLogLine(ctx context.Context, client *client.Client, containerID string,
beforeStartTime time.Time,
) error {
logOptions := container.LogsOptions{
ShowStdout: true,
Follow: true,
Since: beforeStartTime.Format(time.RFC3339Nano),
}
reader, err := client.ContainerLogs(ctx, containerID, logOptions)
if err != nil {
return fmt.Errorf("error getting container logs: %w", err)
}
defer reader.Close()
var linesSeen []string
scanner := bufio.NewScanner(reader)
for ctx.Err() == nil {
if scanner.Scan() {
line := scanner.Text()
if len(line) > 8 { // remove Docker log prefix
line = line[8:]
}
linesSeen = append(linesSeen, line)
if successRegexp.MatchString(line) {
fmt.Println("✅ Success line logged")
return nil
}
continue
}
err := scanner.Err()
if err != nil && err != io.EOF {
logSeenLines(linesSeen)
return fmt.Errorf("reading log stream: %w", err)
}
// The scanner is either done or cannot read because of EOF
fmt.Println("The log scanner stopped")
logSeenLines(linesSeen)
// Check if the container is still running
inspect, err := client.ContainerInspect(ctx, containerID)
if err != nil {
return fmt.Errorf("inspecting container: %w", err)
}
if !inspect.State.Running {
return fmt.Errorf("container stopped unexpectedly while waiting for log line. Exit code: %d", inspect.State.ExitCode)
}
}
return ctx.Err()
}
func logSeenLines(lines []string) {
fmt.Println("Logs seen so far:")
for _, line := range lines {
fmt.Println(" " + line)
} }
return simpleTest(ctx, env)
} }

View File

@@ -1,25 +0,0 @@
package internal
import (
"context"
"fmt"
)
func ProtonVPNTest(ctx context.Context) error {
expectedSecrets := []string{
"Wireguard private key",
}
secrets, err := readSecrets(ctx, expectedSecrets)
if err != nil {
return fmt.Errorf("reading secrets: %w", err)
}
env := []string{
"VPN_SERVICE_PROVIDER=protonvpn",
"VPN_TYPE=wireguard",
"LOG_LEVEL=debug",
"SERVER_COUNTRIES=United States",
"WIREGUARD_PRIVATE_KEY=" + secrets[0],
}
return simpleTest(ctx, env)
}

View File

@@ -1,42 +0,0 @@
package internal
import (
"bufio"
"context"
"fmt"
"os"
"strings"
)
func readSecrets(ctx context.Context, expectedSecrets []string) (lines []string, err error) {
scanner := bufio.NewScanner(os.Stdin)
lines = make([]string, 0, len(expectedSecrets))
for i := range expectedSecrets {
fmt.Println("🤫 reading", expectedSecrets[i], "from Stdin...")
if !scanner.Scan() {
break
}
lines = append(lines, strings.TrimSpace(scanner.Text()))
fmt.Println("🤫 "+expectedSecrets[i], "secret read successfully")
if ctx.Err() != nil {
return nil, ctx.Err()
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("reading secrets from stdin: %w", err)
}
if len(lines) < len(expectedSecrets) {
return nil, fmt.Errorf("expected %d secrets via Stdin, but only received %d",
len(expectedSecrets), len(lines))
}
for i, line := range lines {
if line == "" {
return nil, fmt.Errorf("secret on line %d/%d was empty", i+1, len(lines))
}
}
return lines, nil
}

View File

@@ -1,134 +0,0 @@
package internal
import (
"bufio"
"context"
"fmt"
"io"
"regexp"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
)
func ptrTo[T any](v T) *T { return &v }
func simpleTest(ctx context.Context, env []string) error {
const timeout = 30 * time.Second
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
client, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
return fmt.Errorf("creating Docker client: %w", err)
}
defer client.Close()
config := &container.Config{
Image: "qmcgaw/gluetun",
StopTimeout: ptrTo(3),
Env: env,
}
hostConfig := &container.HostConfig{
AutoRemove: true,
CapAdd: []string{"NET_ADMIN", "NET_RAW"},
}
networkConfig := (*network.NetworkingConfig)(nil)
platform := (*v1.Platform)(nil)
const containerName = "" // auto-generated name
response, err := client.ContainerCreate(ctx, config, hostConfig, networkConfig, platform, containerName)
if err != nil {
return fmt.Errorf("creating container: %w", err)
}
for _, warning := range response.Warnings {
fmt.Println("Warning during container creation:", warning)
}
containerID := response.ID
defer stopContainer(client, containerID)
beforeStartTime := time.Now()
err = client.ContainerStart(ctx, containerID, container.StartOptions{})
if err != nil {
return fmt.Errorf("starting container: %w", err)
}
return waitForLogLine(ctx, client, containerID, beforeStartTime)
}
func stopContainer(client *client.Client, containerID string) {
const stopTimeout = 5 * time.Second // must be higher than 3s, see above [container.Config]'s StopTimeout field
stopCtx, stopCancel := context.WithTimeout(context.Background(), stopTimeout)
defer stopCancel()
err := client.ContainerStop(stopCtx, containerID, container.StopOptions{})
if err != nil {
fmt.Println("failed to stop container:", err)
}
}
var successRegexp = regexp.MustCompile(`^.+Public IP address is .+$`)
func waitForLogLine(ctx context.Context, client *client.Client, containerID string,
beforeStartTime time.Time,
) error {
logOptions := container.LogsOptions{
ShowStdout: true,
Follow: true,
Since: beforeStartTime.Format(time.RFC3339Nano),
}
reader, err := client.ContainerLogs(ctx, containerID, logOptions)
if err != nil {
return fmt.Errorf("error getting container logs: %w", err)
}
defer reader.Close()
var linesSeen []string
scanner := bufio.NewScanner(reader)
for ctx.Err() == nil {
if scanner.Scan() {
line := scanner.Text()
if len(line) > 8 { // remove Docker log prefix
line = line[8:]
}
linesSeen = append(linesSeen, line)
if successRegexp.MatchString(line) {
fmt.Println("✅ Success line logged")
return nil
}
continue
}
err := scanner.Err()
if err != nil && err != io.EOF {
logSeenLines(linesSeen)
return fmt.Errorf("reading log stream: %w", err)
}
// The scanner is either done or cannot read because of EOF
fmt.Println("The log scanner stopped")
logSeenLines(linesSeen)
// Check if the container is still running
inspect, err := client.ContainerInspect(ctx, containerID)
if err != nil {
return fmt.Errorf("inspecting container: %w", err)
}
if !inspect.State.Running {
return fmt.Errorf("container stopped unexpectedly while waiting for log line. Exit code: %d", inspect.State.ExitCode)
}
}
return ctx.Err()
}
func logSeenLines(lines []string) {
fmt.Println("Logs seen so far:")
for _, line := range lines {
fmt.Println(" " + line)
}
}

View File

@@ -164,8 +164,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
} }
} }
defer fmt.Println(gluetunLogo)
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z") announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
if err != nil { if err != nil {
return err return err
@@ -177,7 +175,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
Version: buildInfo.Version, Version: buildInfo.Version,
Commit: buildInfo.Commit, Commit: buildInfo.Commit,
Created: buildInfo.Created, Created: buildInfo.Created,
Announcement: "All control server routes will become private by default after the v3.41.0 release", Announcement: "All control server routes are now private by default",
AnnounceExp: announcementExp, AnnounceExp: announcementExp,
// Sponsor information // Sponsor information
PaypalUser: "qmcgaw", PaypalUser: "qmcgaw",
@@ -602,34 +600,3 @@ type RunStarter interface {
Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string, Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string,
waitError <-chan error, err error) waitError <-chan error, err error)
} }
const gluetunLogo = ` @@@
@@@@
@@@@@@
@@@@.@@ @@@@@@@@@@
@@@@.@@@ @@@@@@@@==@@@@
@@@.@..@@ @@@@@@@=@..==@@@@
@@@@ @@@.@@.@@ @@@@@@===@@@@.=@@@
@...-@@ @@@@.@@.@@@ @@@ @@@@@@=======@@@=@@@@
@@@@@@@@ @@@.-%@.+@@@@@@@@ @@@@@%============@@@@
@@@.--@..@@@@.-@@@@@@@==============@@@@
@@@@ @@@-@--@@.@@.---@@@@@==============#@@@@@
@@@ @@@.@@-@@.@@--@@@@@===============@@@@@@
@@@@.@--@@@@@@@@@@================@@@@@@@
@@@..--@@*@@@@@@================@@@@+*@@
@@@.---@@.@@@@=================@@@@--@@
@@@-.---@@@@@@================@@@@*--@@@
@@@.:-#@@@@@@===============*@@@@.---@@
@@@.-------.@@@============@@@@@@.--@@@
@@@..--------:@@@=========@@@@@@@@.--@@@
@@@.-@@@@@@@@@@@========@@@@@ @@@.--@@
@@.@@@@===============@@@@@ @@@@@@---@@@@@@
@@@@@@@==============@@@@@@@@@@@@*@---@@@@@@@@
@@@@@@=============@@@@@ @@@...------------.*@@@
@@@@%===========@@@@@@ @@@..------@@@@.-----.-@@@
@@@@@@.=======@@@@@@ @@@.-------@@@@@@-.------=@@
@@@@@@@@@===@@@@@@ @@.------@@@@ @@@@.-----@@@
@@@==@@@=@@@@@@@ @@@.-@@@@@@@ @@@@@@@--@@
@@@@@@@@@@@@@ @@@@@@@@ @@@@@@@
@@@@@@@@ @@@@ @@@@
`

2
go.mod
View File

@@ -10,7 +10,7 @@ require (
github.com/klauspost/compress v1.18.1 github.com/klauspost/compress v1.18.1
github.com/klauspost/pgzip v1.2.6 github.com/klauspost/pgzip v1.2.6
github.com/pelletier/go-toml/v2 v2.2.4 github.com/pelletier/go-toml/v2 v2.2.4
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251123213823-54e987293e88 github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251114155417-248acd28339f
github.com/qdm12/gosettings v0.4.4 github.com/qdm12/gosettings v0.4.4
github.com/qdm12/goshutdown v0.3.0 github.com/qdm12/goshutdown v0.3.0
github.com/qdm12/gosplash v0.2.0 github.com/qdm12/gosplash v0.2.0

4
go.sum
View File

@@ -69,8 +69,8 @@ github.com/prometheus/common v0.60.1 h1:FUas6GcOw66yB/73KC+BOZoFJmbo/1pojoILArPA
github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw= github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251123213823-54e987293e88 h1:GJ5FALvJ3UmHjVaNYebrfV5zF5You4dq8HfRWZy2loM= github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251114155417-248acd28339f h1:6wN5D9wACfmXDsQ366egVt0jXY4nqL/QnIwg4nWhXco=
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251123213823-54e987293e88/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE= github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251114155417-248acd28339f/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c= github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c=
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg= github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg=
github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4= github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4=

View File

@@ -7,4 +7,3 @@ func newNoopLogger() *noopLogger {
} }
func (l *noopLogger) Info(string) {} func (l *noopLogger) Info(string) {}
func (l *noopLogger) Warn(string) {}

View File

@@ -38,7 +38,7 @@ type UpdaterLogger interface {
func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) error { func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) error {
options := settings.Updater{} options := settings.Updater{}
var endUserMode, maintainerMode, updateAll bool var endUserMode, maintainerMode, updateAll bool
var csvProviders, ipToken, protonUsername, protonEmail, protonPassword string var csvProviders, ipToken, protonUsername, protonPassword string
flagSet := flag.NewFlagSet("update", flag.ExitOnError) flagSet := flag.NewFlagSet("update", flag.ExitOnError)
flagSet.BoolVar(&endUserMode, "enduser", false, "Write results to /gluetun/servers.json (for end users)") flagSet.BoolVar(&endUserMode, "enduser", false, "Write results to /gluetun/servers.json (for end users)")
flagSet.BoolVar(&maintainerMode, "maintainer", false, flagSet.BoolVar(&maintainerMode, "maintainer", false,
@@ -50,9 +50,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
flagSet.BoolVar(&updateAll, "all", false, "Update servers for all VPN providers") flagSet.BoolVar(&updateAll, "all", false, "Update servers for all VPN providers")
flagSet.StringVar(&csvProviders, "providers", "", "CSV string of VPN providers to update server data for") flagSet.StringVar(&csvProviders, "providers", "", "CSV string of VPN providers to update server data for")
flagSet.StringVar(&ipToken, "ip-token", "", "IP data service token (e.g. ipinfo.io) to use") flagSet.StringVar(&ipToken, "ip-token", "", "IP data service token (e.g. ipinfo.io) to use")
flagSet.StringVar(&protonUsername, "proton-username", "", flagSet.StringVar(&protonUsername, "proton-username", "", "Username to use to authenticate with Proton")
"(Retro-compatibility) Username to use to authenticate with Proton. Use -proton-email instead.") // v4 remove this
flagSet.StringVar(&protonEmail, "proton-email", "", "Email to use to authenticate with Proton")
flagSet.StringVar(&protonPassword, "proton-password", "", "Password to use to authenticate with Proton") flagSet.StringVar(&protonPassword, "proton-password", "", "Password to use to authenticate with Proton")
if err := flagSet.Parse(args); err != nil { if err := flagSet.Parse(args); err != nil {
return err return err
@@ -72,12 +70,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
} }
if slices.Contains(options.Providers, providers.Protonvpn) { if slices.Contains(options.Providers, providers.Protonvpn) {
if protonEmail == "" && protonUsername != "" { options.ProtonUsername = &protonUsername
protonEmail = protonUsername + "@protonmail.com"
logger.Warn("use -proton-email instead of -proton-username in the future. " +
"This assumes the email is " + protonEmail + " and may not work.")
}
options.ProtonEmail = &protonEmail
options.ProtonPassword = &protonPassword options.ProtonPassword = &protonPassword
} }
@@ -88,11 +81,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
return fmt.Errorf("options validation failed: %w", err) return fmt.Errorf("options validation failed: %w", err)
} }
serversDataPath := constants.ServersData storage, err := storage.New(logger, constants.ServersData)
if maintainerMode {
serversDataPath = ""
}
storage, err := storage.New(logger, serversDataPath)
if err != nil { if err != nil {
return fmt.Errorf("creating servers storage: %w", err) return fmt.Errorf("creating servers storage: %w", err)
} }

View File

@@ -37,7 +37,7 @@ var (
ErrSystemTimezoneNotValid = errors.New("timezone is not valid") ErrSystemTimezoneNotValid = errors.New("timezone is not valid")
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small") ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
ErrUpdaterProtonPasswordMissing = errors.New("proton password is missing") ErrUpdaterProtonPasswordMissing = errors.New("proton password is missing")
ErrUpdaterProtonEmailMissing = errors.New("proton email is missing") ErrUpdaterProtonUsernameMissing = errors.New("proton username is missing")
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid") ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
ErrVPNTypeNotValid = errors.New("VPN type is not valid") ErrVPNTypeNotValid = errors.New("VPN type is not valid")
ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set") ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set")

View File

@@ -1,7 +1,6 @@
package settings package settings
import ( import (
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"os" "os"
@@ -18,63 +17,34 @@ type Health struct {
// for the health check server. // for the health check server.
// It cannot be the empty string in the internal state. // It cannot be the empty string in the internal state.
ServerAddress string ServerAddress string
// TargetAddresses are the addresses (host or host:port) // TargetAddress is the address (host or host:port)
// to TCP TLS dial to periodically for the health check. // to TCP TLS dial to periodically for the health check.
// Addresses after the first one are used as fallbacks for retries. // It cannot be the empty string in the internal state.
// It cannot be empty in the internal state. TargetAddress string
TargetAddresses []string // ICMPTargetIP is the IP address to use for ICMP echo requests
// ICMPTargetIPs are the IP addresses to use for ICMP echo requests // in the health checker. It can be set to an unspecified address (0.0.0.0)
// in the health checker. The slice can be set to a single // such that the VPN server IP is used, which is also the default behavior.
// unspecified address (0.0.0.0) such that the VPN server IP is used, ICMPTargetIP netip.Addr
// although this can be less reliable. It defaults to [1.1.1.1,8.8.8.8],
// and cannot be left empty in the internal state.
ICMPTargetIPs []netip.Addr
// SmallCheckType is the type of small health check to perform.
// It can be "icmp" or "dns", and defaults to "icmp".
// Note it changes automatically to dns if icmp is not supported.
SmallCheckType string
// RestartVPN indicates whether to restart the VPN connection // RestartVPN indicates whether to restart the VPN connection
// when the healthcheck fails. // when the healthcheck fails.
RestartVPN *bool RestartVPN *bool
} }
var (
ErrICMPTargetIPNotValid = errors.New("ICMP target IP address is not valid")
ErrICMPTargetIPsNotCompatible = errors.New("ICMP target IP addresses are not compatible")
ErrSmallCheckTypeNotValid = errors.New("small check type is not valid")
)
func (h Health) Validate() (err error) { func (h Health) Validate() (err error) {
err = validate.ListeningAddress(h.ServerAddress, os.Getuid()) err = validate.ListeningAddress(h.ServerAddress, os.Getuid())
if err != nil { if err != nil {
return fmt.Errorf("server listening address is not valid: %w", err) return fmt.Errorf("server listening address is not valid: %w", err)
} }
for _, ip := range h.ICMPTargetIPs {
switch {
case !ip.IsValid():
return fmt.Errorf("%w: %s", ErrICMPTargetIPNotValid, ip)
case ip.IsUnspecified() && len(h.ICMPTargetIPs) > 1:
return fmt.Errorf("%w: only a single IP address must be set if it is to be unspecified",
ErrICMPTargetIPsNotCompatible)
}
}
err = validate.IsOneOf(h.SmallCheckType, "icmp", "dns")
if err != nil {
return fmt.Errorf("%w: %s", ErrSmallCheckTypeNotValid, err)
}
return nil return nil
} }
func (h *Health) copy() (copied Health) { func (h *Health) copy() (copied Health) {
return Health{ return Health{
ServerAddress: h.ServerAddress, ServerAddress: h.ServerAddress,
TargetAddresses: h.TargetAddresses, TargetAddress: h.TargetAddress,
ICMPTargetIPs: gosettings.CopySlice(h.ICMPTargetIPs), ICMPTargetIP: h.ICMPTargetIP,
SmallCheckType: h.SmallCheckType, RestartVPN: gosettings.CopyPointer(h.RestartVPN),
RestartVPN: gosettings.CopyPointer(h.RestartVPN),
} }
} }
@@ -83,20 +53,15 @@ func (h *Health) copy() (copied Health) {
// settings. // settings.
func (h *Health) OverrideWith(other Health) { func (h *Health) OverrideWith(other Health) {
h.ServerAddress = gosettings.OverrideWithComparable(h.ServerAddress, other.ServerAddress) h.ServerAddress = gosettings.OverrideWithComparable(h.ServerAddress, other.ServerAddress)
h.TargetAddresses = gosettings.OverrideWithSlice(h.TargetAddresses, other.TargetAddresses) h.TargetAddress = gosettings.OverrideWithComparable(h.TargetAddress, other.TargetAddress)
h.ICMPTargetIPs = gosettings.OverrideWithSlice(h.ICMPTargetIPs, other.ICMPTargetIPs) h.ICMPTargetIP = gosettings.OverrideWithComparable(h.ICMPTargetIP, other.ICMPTargetIP)
h.SmallCheckType = gosettings.OverrideWithComparable(h.SmallCheckType, other.SmallCheckType)
h.RestartVPN = gosettings.OverrideWithPointer(h.RestartVPN, other.RestartVPN) h.RestartVPN = gosettings.OverrideWithPointer(h.RestartVPN, other.RestartVPN)
} }
func (h *Health) SetDefaults() { func (h *Health) SetDefaults() {
h.ServerAddress = gosettings.DefaultComparable(h.ServerAddress, "127.0.0.1:9999") h.ServerAddress = gosettings.DefaultComparable(h.ServerAddress, "127.0.0.1:9999")
h.TargetAddresses = gosettings.DefaultSlice(h.TargetAddresses, []string{"cloudflare.com:443", "github.com:443"}) h.TargetAddress = gosettings.DefaultComparable(h.TargetAddress, "cloudflare.com:443")
h.ICMPTargetIPs = gosettings.DefaultSlice(h.ICMPTargetIPs, []netip.Addr{ h.ICMPTargetIP = gosettings.DefaultComparable(h.ICMPTargetIP, netip.IPv4Unspecified()) // use the VPN server IP
netip.AddrFrom4([4]byte{1, 1, 1, 1}),
netip.AddrFrom4([4]byte{8, 8, 8, 8}),
})
h.SmallCheckType = gosettings.DefaultComparable(h.SmallCheckType, "icmp")
h.RestartVPN = gosettings.DefaultPointer(h.RestartVPN, true) h.RestartVPN = gosettings.DefaultPointer(h.RestartVPN, true)
} }
@@ -107,37 +72,24 @@ func (h Health) String() string {
func (h Health) toLinesNode() (node *gotree.Node) { func (h Health) toLinesNode() (node *gotree.Node) {
node = gotree.New("Health settings:") node = gotree.New("Health settings:")
node.Appendf("Server listening address: %s", h.ServerAddress) node.Appendf("Server listening address: %s", h.ServerAddress)
targetAddrs := node.Appendf("Target addresses:") node.Appendf("Target address: %s", h.TargetAddress)
for _, targetAddr := range h.TargetAddresses { icmpTarget := "VPN server IP"
targetAddrs.Append(targetAddr) if !h.ICMPTargetIP.IsUnspecified() {
} icmpTarget = h.ICMPTargetIP.String()
switch h.SmallCheckType {
case "icmp":
icmpNode := node.Appendf("Small health check type: ICMP echo request")
if len(h.ICMPTargetIPs) == 1 && h.ICMPTargetIPs[0].IsUnspecified() {
icmpNode.Appendf("ICMP target IP: VPN server IP address")
} else {
icmpIPs := icmpNode.Appendf("ICMP target IPs:")
for _, ip := range h.ICMPTargetIPs {
icmpIPs.Append(ip.String())
}
}
case "dns":
node.Appendf("Small health check type: Plain DNS lookup over UDP")
} }
node.Appendf("ICMP target IP: %s", icmpTarget)
node.Appendf("Restart VPN on healthcheck failure: %s", gosettings.BoolToYesNo(h.RestartVPN)) node.Appendf("Restart VPN on healthcheck failure: %s", gosettings.BoolToYesNo(h.RestartVPN))
return node return node
} }
func (h *Health) Read(r *reader.Reader) (err error) { func (h *Health) Read(r *reader.Reader) (err error) {
h.ServerAddress = r.String("HEALTH_SERVER_ADDRESS") h.ServerAddress = r.String("HEALTH_SERVER_ADDRESS")
h.TargetAddresses = r.CSV("HEALTH_TARGET_ADDRESSES", h.TargetAddress = r.String("HEALTH_TARGET_ADDRESS",
reader.RetroKeys("HEALTH_ADDRESS_TO_PING", "HEALTH_TARGET_ADDRESS")) reader.RetroKeys("HEALTH_ADDRESS_TO_PING"))
h.ICMPTargetIPs, err = r.CSVNetipAddresses("HEALTH_ICMP_TARGET_IPS", reader.RetroKeys("HEALTH_ICMP_TARGET_IP")) h.ICMPTargetIP, err = r.NetipAddr("HEALTH_ICMP_TARGET_IP")
if err != nil { if err != nil {
return err return err
} }
h.SmallCheckType = r.String("HEALTH_SMALL_CHECK_TYPE")
h.RestartVPN, err = r.BoolPtr("HEALTH_RESTART_VPN") h.RestartVPN, err = r.BoolPtr("HEALTH_RESTART_VPN")
if err != nil { if err != nil {
return err return err

View File

@@ -2,7 +2,6 @@ package settings
import ( import (
"fmt" "fmt"
"net/netip"
"strings" "strings"
"github.com/qdm12/gluetun/internal/configuration/settings/helpers" "github.com/qdm12/gluetun/internal/configuration/settings/helpers"
@@ -25,12 +24,6 @@ type OpenVPNSelection struct {
// and can be udp or tcp. It cannot be the empty string // and can be udp or tcp. It cannot be the empty string
// in the internal state. // in the internal state.
Protocol string `json:"protocol"` Protocol string `json:"protocol"`
// EndpointIP is the server endpoint IP address.
// If set, it overrides any IP address from the picked
// built-in server connection. To indicate it should
// not be used, it should be set to [netip.IPv4Unspecified].
// It can never be the zero value in the internal state.
EndpointIP netip.Addr `json:"endpoint_ip"`
// CustomPort is the OpenVPN server endpoint port. // CustomPort is the OpenVPN server endpoint port.
// It can be set to 0 to indicate no custom port should // It can be set to 0 to indicate no custom port should
// be used. It cannot be nil in the internal state. // be used. It cannot be nil in the internal state.
@@ -149,7 +142,6 @@ func (o *OpenVPNSelection) copy() (copied OpenVPNSelection) {
return OpenVPNSelection{ return OpenVPNSelection{
ConfFile: gosettings.CopyPointer(o.ConfFile), ConfFile: gosettings.CopyPointer(o.ConfFile),
Protocol: o.Protocol, Protocol: o.Protocol,
EndpointIP: o.EndpointIP,
CustomPort: gosettings.CopyPointer(o.CustomPort), CustomPort: gosettings.CopyPointer(o.CustomPort),
PIAEncPreset: gosettings.CopyPointer(o.PIAEncPreset), PIAEncPreset: gosettings.CopyPointer(o.PIAEncPreset),
} }
@@ -159,14 +151,12 @@ func (o *OpenVPNSelection) overrideWith(other OpenVPNSelection) {
o.ConfFile = gosettings.OverrideWithPointer(o.ConfFile, other.ConfFile) o.ConfFile = gosettings.OverrideWithPointer(o.ConfFile, other.ConfFile)
o.Protocol = gosettings.OverrideWithComparable(o.Protocol, other.Protocol) o.Protocol = gosettings.OverrideWithComparable(o.Protocol, other.Protocol)
o.CustomPort = gosettings.OverrideWithPointer(o.CustomPort, other.CustomPort) o.CustomPort = gosettings.OverrideWithPointer(o.CustomPort, other.CustomPort)
o.EndpointIP = gosettings.OverrideWithValidator(o.EndpointIP, other.EndpointIP)
o.PIAEncPreset = gosettings.OverrideWithPointer(o.PIAEncPreset, other.PIAEncPreset) o.PIAEncPreset = gosettings.OverrideWithPointer(o.PIAEncPreset, other.PIAEncPreset)
} }
func (o *OpenVPNSelection) setDefaults(vpnProvider string) { func (o *OpenVPNSelection) setDefaults(vpnProvider string) {
o.ConfFile = gosettings.DefaultPointer(o.ConfFile, "") o.ConfFile = gosettings.DefaultPointer(o.ConfFile, "")
o.Protocol = gosettings.DefaultComparable(o.Protocol, constants.UDP) o.Protocol = gosettings.DefaultComparable(o.Protocol, constants.UDP)
o.EndpointIP = gosettings.DefaultValidator(o.EndpointIP, netip.IPv4Unspecified())
o.CustomPort = gosettings.DefaultPointer(o.CustomPort, 0) o.CustomPort = gosettings.DefaultPointer(o.CustomPort, 0)
var defaultEncPreset string var defaultEncPreset string
@@ -184,10 +174,6 @@ func (o OpenVPNSelection) toLinesNode() (node *gotree.Node) {
node = gotree.New("OpenVPN server selection settings:") node = gotree.New("OpenVPN server selection settings:")
node.Appendf("Protocol: %s", strings.ToUpper(o.Protocol)) node.Appendf("Protocol: %s", strings.ToUpper(o.Protocol))
if !o.EndpointIP.IsUnspecified() {
node.Appendf("Endpoint IP address: %s", o.EndpointIP)
}
if *o.CustomPort != 0 { if *o.CustomPort != 0 {
node.Appendf("Custom port: %d", *o.CustomPort) node.Appendf("Custom port: %d", *o.CustomPort)
} }
@@ -208,12 +194,6 @@ func (o *OpenVPNSelection) read(r *reader.Reader) (err error) {
o.Protocol = r.String("OPENVPN_PROTOCOL", reader.RetroKeys("PROTOCOL")) o.Protocol = r.String("OPENVPN_PROTOCOL", reader.RetroKeys("PROTOCOL"))
o.EndpointIP, err = r.NetipAddr("OPENVPN_ENDPOINT_IP",
reader.RetroKeys("OPENVPN_TARGET_IP", "VPN_ENDPOINT_IP"))
if err != nil {
return err
}
o.CustomPort, err = r.Uint16Ptr("OPENVPN_ENDPOINT_PORT", o.CustomPort, err = r.Uint16Ptr("OPENVPN_ENDPOINT_PORT",
reader.RetroKeys("PORT", "OPENVPN_PORT", "VPN_ENDPOINT_PORT")) reader.RetroKeys("PORT", "OPENVPN_PORT", "VPN_ENDPOINT_PORT"))
if err != nil { if err != nil {

View File

@@ -3,6 +3,7 @@ package settings
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/netip"
"strings" "strings"
"github.com/qdm12/gluetun/internal/configuration/settings/helpers" "github.com/qdm12/gluetun/internal/configuration/settings/helpers"
@@ -21,6 +22,12 @@ type ServerSelection struct {
// or 'wireguard'. It cannot be the empty string // or 'wireguard'. It cannot be the empty string
// in the internal state. // in the internal state.
VPN string `json:"vpn"` VPN string `json:"vpn"`
// TargetIP is the server endpoint IP address to use.
// It will override any IP address from the picked
// built-in server. It cannot be the empty value in the internal
// state, and can be set to the unspecified address to indicate
// there is not target IP address to use.
TargetIP netip.Addr `json:"target_ip"`
// Countries is the list of countries to filter VPN servers with. // Countries is the list of countries to filter VPN servers with.
Countries []string `json:"countries"` Countries []string `json:"countries"`
// Categories is the list of categories to filter VPN servers with. // Categories is the list of categories to filter VPN servers with.
@@ -292,6 +299,7 @@ func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string)
func (ss *ServerSelection) copy() (copied ServerSelection) { func (ss *ServerSelection) copy() (copied ServerSelection) {
return ServerSelection{ return ServerSelection{
VPN: ss.VPN, VPN: ss.VPN,
TargetIP: ss.TargetIP,
Countries: gosettings.CopySlice(ss.Countries), Countries: gosettings.CopySlice(ss.Countries),
Categories: gosettings.CopySlice(ss.Categories), Categories: gosettings.CopySlice(ss.Categories),
Regions: gosettings.CopySlice(ss.Regions), Regions: gosettings.CopySlice(ss.Regions),
@@ -315,6 +323,7 @@ func (ss *ServerSelection) copy() (copied ServerSelection) {
func (ss *ServerSelection) overrideWith(other ServerSelection) { func (ss *ServerSelection) overrideWith(other ServerSelection) {
ss.VPN = gosettings.OverrideWithComparable(ss.VPN, other.VPN) ss.VPN = gosettings.OverrideWithComparable(ss.VPN, other.VPN)
ss.TargetIP = gosettings.OverrideWithValidator(ss.TargetIP, other.TargetIP)
ss.Countries = gosettings.OverrideWithSlice(ss.Countries, other.Countries) ss.Countries = gosettings.OverrideWithSlice(ss.Countries, other.Countries)
ss.Categories = gosettings.OverrideWithSlice(ss.Categories, other.Categories) ss.Categories = gosettings.OverrideWithSlice(ss.Categories, other.Categories)
ss.Regions = gosettings.OverrideWithSlice(ss.Regions, other.Regions) ss.Regions = gosettings.OverrideWithSlice(ss.Regions, other.Regions)
@@ -337,6 +346,7 @@ func (ss *ServerSelection) overrideWith(other ServerSelection) {
func (ss *ServerSelection) setDefaults(vpnProvider string, portForwardingEnabled bool) { func (ss *ServerSelection) setDefaults(vpnProvider string, portForwardingEnabled bool) {
ss.VPN = gosettings.DefaultComparable(ss.VPN, vpn.OpenVPN) ss.VPN = gosettings.DefaultComparable(ss.VPN, vpn.OpenVPN)
ss.TargetIP = gosettings.DefaultValidator(ss.TargetIP, netip.IPv4Unspecified())
ss.OwnedOnly = gosettings.DefaultPointer(ss.OwnedOnly, false) ss.OwnedOnly = gosettings.DefaultPointer(ss.OwnedOnly, false)
ss.FreeOnly = gosettings.DefaultPointer(ss.FreeOnly, false) ss.FreeOnly = gosettings.DefaultPointer(ss.FreeOnly, false)
ss.PremiumOnly = gosettings.DefaultPointer(ss.PremiumOnly, false) ss.PremiumOnly = gosettings.DefaultPointer(ss.PremiumOnly, false)
@@ -358,6 +368,9 @@ func (ss ServerSelection) String() string {
func (ss ServerSelection) toLinesNode() (node *gotree.Node) { func (ss ServerSelection) toLinesNode() (node *gotree.Node) {
node = gotree.New("Server selection settings:") node = gotree.New("Server selection settings:")
node.Appendf("VPN type: %s", ss.VPN) node.Appendf("VPN type: %s", ss.VPN)
if !ss.TargetIP.IsUnspecified() {
node.Appendf("Target IP address: %s", ss.TargetIP)
}
if len(ss.Countries) > 0 { if len(ss.Countries) > 0 {
node.Appendf("Countries: %s", strings.Join(ss.Countries, ", ")) node.Appendf("Countries: %s", strings.Join(ss.Countries, ", "))
@@ -448,6 +461,12 @@ func (ss *ServerSelection) read(r *reader.Reader,
) (err error) { ) (err error) {
ss.VPN = vpnType ss.VPN = vpnType
ss.TargetIP, err = r.NetipAddr("OPENVPN_ENDPOINT_IP",
reader.RetroKeys("OPENVPN_TARGET_IP", "VPN_ENDPOINT_IP"))
if err != nil {
return err
}
countriesRetroKeys := []string{"COUNTRY"} countriesRetroKeys := []string{"COUNTRY"}
if vpnProvider == providers.Cyberghost { if vpnProvider == providers.Cyberghost {
countriesRetroKeys = append(countriesRetroKeys, "REGION") countriesRetroKeys = append(countriesRetroKeys, "REGION")

View File

@@ -57,13 +57,8 @@ func Test_Settings_String(t *testing.T) {
| └── Log level: INFO | └── Log level: INFO
├── Health settings: ├── Health settings:
| ├── Server listening address: 127.0.0.1:9999 | ├── Server listening address: 127.0.0.1:9999
| ├── Target addresses: | ├── Target address: cloudflare.com:443
| | ├── cloudflare.com:443 | ├── ICMP target IP: VPN server IP
| | └── github.com:443
| ├── Small health check type: ICMP echo request
| | └── ICMP target IPs:
| | ├── 1.1.1.1
| | └── 8.8.8.8
| └── Restart VPN on healthcheck failure: yes | └── Restart VPN on healthcheck failure: yes
├── Shadowsocks server settings: ├── Shadowsocks server settings:
| └── Enabled: no | └── Enabled: no

View File

@@ -32,8 +32,8 @@ type Updater struct {
// Providers is the list of VPN service providers // Providers is the list of VPN service providers
// to update server information for. // to update server information for.
Providers []string Providers []string
// ProtonEmail is the email to authenticate with the Proton API. // ProtonUsername is the username to authenticate with the Proton API.
ProtonEmail *string ProtonUsername *string
// ProtonPassword is the password to authenticate with the Proton API. // ProtonPassword is the password to authenticate with the Proton API.
ProtonPassword *string ProtonPassword *string
} }
@@ -58,11 +58,11 @@ func (u Updater) Validate() (err error) {
} }
if provider == providers.Protonvpn { if provider == providers.Protonvpn {
authenticatedAPI := *u.ProtonEmail != "" || *u.ProtonPassword != "" authenticatedAPI := *u.ProtonUsername != "" || *u.ProtonPassword != ""
if authenticatedAPI { if authenticatedAPI {
switch { switch {
case *u.ProtonEmail == "": case *u.ProtonUsername == "":
return fmt.Errorf("%w", ErrUpdaterProtonEmailMissing) return fmt.Errorf("%w", ErrUpdaterProtonUsernameMissing)
case *u.ProtonPassword == "": case *u.ProtonPassword == "":
return fmt.Errorf("%w", ErrUpdaterProtonPasswordMissing) return fmt.Errorf("%w", ErrUpdaterProtonPasswordMissing)
} }
@@ -79,7 +79,7 @@ func (u *Updater) copy() (copied Updater) {
DNSAddress: u.DNSAddress, DNSAddress: u.DNSAddress,
MinRatio: u.MinRatio, MinRatio: u.MinRatio,
Providers: gosettings.CopySlice(u.Providers), Providers: gosettings.CopySlice(u.Providers),
ProtonEmail: gosettings.CopyPointer(u.ProtonEmail), ProtonUsername: gosettings.CopyPointer(u.ProtonUsername),
ProtonPassword: gosettings.CopyPointer(u.ProtonPassword), ProtonPassword: gosettings.CopyPointer(u.ProtonPassword),
} }
} }
@@ -92,7 +92,7 @@ func (u *Updater) overrideWith(other Updater) {
u.DNSAddress = gosettings.OverrideWithComparable(u.DNSAddress, other.DNSAddress) u.DNSAddress = gosettings.OverrideWithComparable(u.DNSAddress, other.DNSAddress)
u.MinRatio = gosettings.OverrideWithComparable(u.MinRatio, other.MinRatio) u.MinRatio = gosettings.OverrideWithComparable(u.MinRatio, other.MinRatio)
u.Providers = gosettings.OverrideWithSlice(u.Providers, other.Providers) u.Providers = gosettings.OverrideWithSlice(u.Providers, other.Providers)
u.ProtonEmail = gosettings.OverrideWithPointer(u.ProtonEmail, other.ProtonEmail) u.ProtonUsername = gosettings.OverrideWithPointer(u.ProtonUsername, other.ProtonUsername)
u.ProtonPassword = gosettings.OverrideWithPointer(u.ProtonPassword, other.ProtonPassword) u.ProtonPassword = gosettings.OverrideWithPointer(u.ProtonPassword, other.ProtonPassword)
} }
@@ -110,7 +110,7 @@ func (u *Updater) SetDefaults(vpnProvider string) {
} }
// Set these to empty strings to avoid nil pointer panics // Set these to empty strings to avoid nil pointer panics
u.ProtonEmail = gosettings.DefaultPointer(u.ProtonEmail, "") u.ProtonUsername = gosettings.DefaultPointer(u.ProtonUsername, "")
u.ProtonPassword = gosettings.DefaultPointer(u.ProtonPassword, "") u.ProtonPassword = gosettings.DefaultPointer(u.ProtonPassword, "")
} }
@@ -129,7 +129,7 @@ func (u Updater) toLinesNode() (node *gotree.Node) {
node.Appendf("Minimum ratio: %.1f", u.MinRatio) node.Appendf("Minimum ratio: %.1f", u.MinRatio)
node.Appendf("Providers to update: %s", strings.Join(u.Providers, ", ")) node.Appendf("Providers to update: %s", strings.Join(u.Providers, ", "))
if slices.Contains(u.Providers, providers.Protonvpn) { if slices.Contains(u.Providers, providers.Protonvpn) {
node.Appendf("Proton API email: %s", *u.ProtonEmail) node.Appendf("Proton API username: %s", *u.ProtonUsername)
node.Appendf("Proton API password: %s", gosettings.ObfuscateKey(*u.ProtonPassword)) node.Appendf("Proton API password: %s", gosettings.ObfuscateKey(*u.ProtonPassword))
} }
@@ -154,13 +154,11 @@ func (u *Updater) read(r *reader.Reader) (err error) {
u.Providers = r.CSV("UPDATER_VPN_SERVICE_PROVIDERS") u.Providers = r.CSV("UPDATER_VPN_SERVICE_PROVIDERS")
u.ProtonEmail = r.Get("UPDATER_PROTONVPN_EMAIL") u.ProtonUsername = r.Get("UPDATER_PROTONVPN_USERNAME")
if u.ProtonEmail == nil { if u.ProtonUsername != nil {
protonUsername := r.String("UPDATER_PROTONVPN_USERNAME", reader.IsRetro("UPDATER_PROTONVPN_EMAIL")) // Enforce to use the username not the email address
if protonUsername != "" { *u.ProtonUsername = strings.TrimSuffix(*u.ProtonUsername, "@protonmail.com")
protonEmail := protonUsername + "@protonmail.com" *u.ProtonUsername = strings.TrimSuffix(*u.ProtonUsername, "@proton.me")
u.ProtonEmail = &protonEmail
}
} }
u.ProtonPassword = r.Get("UPDATER_PROTONVPN_PASSWORD") u.ProtonPassword = r.Get("UPDATER_PROTONVPN_PASSWORD")

View File

@@ -14,11 +14,11 @@ import (
type WireguardSelection struct { type WireguardSelection struct {
// EndpointIP is the server endpoint IP address. // EndpointIP is the server endpoint IP address.
// It is notably required with the custom provider. // It is only used with VPN providers generating Wireguard
// Otherwise it overrides any IP address from the picked // configurations specific to each server and user.
// built-in server connection. To indicate it should // To indicate it should not be used, it should be set
// not be used, it should be set to [netip.IPv4Unspecified]. // to netip.IPv4Unspecified(). It can never be the zero value
// It can never be the zero value in the internal state. // in the internal state.
EndpointIP netip.Addr `json:"endpoint_ip"` EndpointIP netip.Addr `json:"endpoint_ip"`
// EndpointPort is a the server port to use for the VPN server. // EndpointPort is a the server port to use for the VPN server.
// It is optional for VPN providers IVPN, Mullvad, Surfshark // It is optional for VPN providers IVPN, Mullvad, Surfshark

View File

@@ -45,6 +45,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
if err == nil { if err == nil {
l.backoffTime = defaultBackoffTime l.backoffTime = defaultBackoffTime
l.logger.Info("ready") l.logger.Info("ready")
l.signalOrSetStatus(constants.Running)
break break
} }
@@ -61,7 +62,6 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
l.logAndWait(ctx, err) l.logAndWait(ctx, err)
settings = l.GetSettings() settings = l.GetSettings()
} }
l.signalOrSetStatus(constants.Running)
settings = l.GetSettings() settings = l.GetSettings()
if !*settings.KeepNameserver && !*settings.ServerEnabled { if !*settings.KeepNameserver && !*settings.ServerEnabled {
@@ -82,19 +82,15 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
if !*l.GetSettings().KeepNameserver { l.stopServer()
l.stopServer() // TODO revert OS and Go nameserver when exiting
// TODO revert OS and Go nameserver when exiting
}
return true return true
case <-l.stop: case <-l.stop:
l.userTrigger = true l.userTrigger = true
l.logger.Info("stopping") l.logger.Info("stopping")
if !*l.GetSettings().KeepNameserver { const fallback = false
const fallback = false l.useUnencryptedDNS(fallback)
l.useUnencryptedDNS(fallback) l.stopServer()
l.stopServer()
}
l.stopped <- struct{}{} l.stopped <- struct{}{}
case <-l.start: case <-l.start:
l.userTrigger = true l.userTrigger = true

View File

@@ -16,16 +16,16 @@ import (
) )
type Checker struct { type Checker struct {
tlsDialAddrs []string tlsDialAddr string
dialer *net.Dialer dialer *net.Dialer
echoer *icmp.Echoer echoer *icmp.Echoer
dnsClient *dns.Client dnsClient *dns.Client
logger Logger logger Logger
icmpTargetIPs []netip.Addr icmpTarget netip.Addr
smallCheckType string configMutex sync.Mutex
configMutex sync.Mutex
icmpNotPermitted bool icmpNotPermitted bool
smallCheckName string
// Internal periodic service signals // Internal periodic service signals
stop context.CancelFunc stop context.CancelFunc
@@ -45,37 +45,35 @@ func NewChecker(logger Logger) *Checker {
} }
} }
// SetConfig sets the TCP+TLS dial addresses, the ICMP echo IP address // SetConfig sets the TCP+TLS dial address and the ICMP echo IP address
// to target and the desired small check type (dns or icmp). // to target by the [Checker].
// This function MUST be called before calling [Checker.Start]. // This function MUST be called before calling [Checker.Start].
func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr, func (c *Checker) SetConfig(tlsDialAddr string, icmpTarget netip.Addr) {
smallCheckType string,
) {
c.configMutex.Lock() c.configMutex.Lock()
defer c.configMutex.Unlock() defer c.configMutex.Unlock()
c.tlsDialAddrs = tlsDialAddrs c.tlsDialAddr = tlsDialAddr
c.icmpTargetIPs = icmpTargets c.icmpTarget = icmpTarget
c.smallCheckType = smallCheckType
} }
// Start starts the checker by first running a blocking 6s-timed TCP+TLS check, // Start starts the checker by first running a blocking 2s-timed TCP+TLS check,
// and, on success, starts the periodic checks in a separate goroutine: // and, on success, starts the periodic checks in a separate goroutine:
// - a "small" ICMP echo check every minute // - a "small" ICMP echo check every 15 seconds
// - a "full" TCP+TLS check every 5 minutes // - a "full" TCP+TLS check every 5 minutes
// It returns a channel `runError` that receives an error (nil or not) when a periodic check is performed. // It returns a channel `runError` that receives an error (nil or not) when a periodic check is performed.
// It returns an error if the initial TCP+TLS check fails. // It returns an error if the initial TCP+TLS check fails.
// The Checker has to be ultimately stopped by calling [Checker.Stop]. // The Checker has to be ultimately stopped by calling [Checker.Stop].
func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) { func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) {
if len(c.tlsDialAddrs) == 0 || len(c.icmpTargetIPs) == 0 || c.smallCheckType == "" { if c.tlsDialAddr == "" || c.icmpTarget.IsUnspecified() {
panic("call Checker.SetConfig with non empty values before Checker.Start") panic("call Checker.SetConfig with non empty values before Checker.Start")
} }
if c.icmpNotPermitted { // connection isn't under load yet when the checker starts, so a short
// restore forced check type to dns if icmp was found to be not permitted // 6 seconds timeout suffices and provides quick enough feedback that
c.smallCheckType = smallCheckDNS // the new connection is not working.
} const timeout = 6 * time.Second
tcpTLSCheckCtx, tcpTLSCheckCancel := context.WithTimeout(ctx, timeout)
err = c.startupCheck(ctx) err = tcpTLSCheck(tcpTLSCheckCtx, c.dialer, c.tlsDialAddr)
tcpTLSCheckCancel()
if err != nil { if err != nil {
return nil, fmt.Errorf("startup check: %w", err) return nil, fmt.Errorf("startup check: %w", err)
} }
@@ -85,6 +83,7 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
c.stop = cancel c.stop = cancel
done := make(chan struct{}) done := make(chan struct{})
c.done = done c.done = done
c.smallCheckName = "ICMP echo"
const smallCheckPeriod = time.Minute const smallCheckPeriod = time.Minute
smallCheckTimer := time.NewTimer(smallCheckPeriod) smallCheckTimer := time.NewTimer(smallCheckPeriod)
const fullCheckPeriod = 5 * time.Minute const fullCheckPeriod = 5 * time.Minute
@@ -124,56 +123,43 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
func (c *Checker) Stop() error { func (c *Checker) Stop() error {
c.stop() c.stop()
<-c.done <-c.done
c.tlsDialAddrs = nil c.icmpTarget = netip.Addr{}
c.icmpTargetIPs = nil
c.smallCheckType = ""
return nil return nil
} }
func (c *Checker) smallPeriodicCheck(ctx context.Context) error { func (c *Checker) smallPeriodicCheck(ctx context.Context) error {
c.configMutex.Lock() c.configMutex.Lock()
icmpTargetIPs := make([]netip.Addr, len(c.icmpTargetIPs)) ip := c.icmpTarget
copy(icmpTargetIPs, c.icmpTargetIPs)
c.configMutex.Unlock() c.configMutex.Unlock()
tryTimeouts := []time.Duration{ const maxTries = 3
5 * time.Second, const timeout = 10 * time.Second
5 * time.Second, const extraTryTime = 10 * time.Second // 10s added for each subsequent retry
5 * time.Second, check := func(ctx context.Context) error {
10 * time.Second, if c.icmpNotPermitted {
10 * time.Second,
10 * time.Second,
15 * time.Second,
15 * time.Second,
15 * time.Second,
30 * time.Second,
}
check := func(ctx context.Context, try int) error {
if c.smallCheckType == smallCheckDNS {
return c.dnsClient.Check(ctx) return c.dnsClient.Check(ctx)
} }
ip := icmpTargetIPs[try%len(icmpTargetIPs)]
err := c.echoer.Echo(ctx, ip) err := c.echoer.Echo(ctx, ip)
if errors.Is(err, icmp.ErrNotPermitted) { if errors.Is(err, icmp.ErrNotPermitted) {
c.icmpNotPermitted = true c.icmpNotPermitted = true
c.smallCheckType = smallCheckDNS c.smallCheckName = "plain DNS over UDP"
c.logger.Infof("%s; permanently falling back to %s checks", c.logger.Infof("%s; permanently falling back to %s checks.", c.smallCheckName, err)
smallCheckTypeToString(c.smallCheckType), err)
return c.dnsClient.Check(ctx) return c.dnsClient.Check(ctx)
} }
return err return err
} }
return withRetries(ctx, tryTimeouts, c.logger, smallCheckTypeToString(c.smallCheckType), check) return withRetries(ctx, maxTries, timeout, extraTryTime, c.logger, c.smallCheckName, check)
} }
func (c *Checker) fullPeriodicCheck(ctx context.Context) error { func (c *Checker) fullPeriodicCheck(ctx context.Context) error {
const maxTries = 2
// 20s timeout in case the connection is under stress // 20s timeout in case the connection is under stress
// See https://github.com/qdm12/gluetun/issues/2270 // See https://github.com/qdm12/gluetun/issues/2270
tryTimeouts := []time.Duration{10 * time.Second, 15 * time.Second, 30 * time.Second} const timeout = 20 * time.Second
check := func(ctx context.Context, try int) error { const extraTryTime = 10 * time.Second // 10s added for each subsequent retry
tlsDialAddr := c.tlsDialAddrs[try%len(c.tlsDialAddrs)] check := func(ctx context.Context) error {
return tcpTLSCheck(ctx, c.dialer, tlsDialAddr) return tcpTLSCheck(ctx, c.dialer, c.tlsDialAddr)
} }
return withRetries(ctx, tryTimeouts, c.logger, "TCP+TLS dial", check) return withRetries(ctx, maxTries, timeout, extraTryTime, c.logger, "TCP+TLS dial", check)
} }
func tcpTLSCheck(ctx context.Context, dialer *net.Dialer, targetAddress string) error { func tcpTLSCheck(ctx context.Context, dialer *net.Dialer, targetAddress string) error {
@@ -232,19 +218,15 @@ func makeAddressToDial(address string) (addressToDial string, err error) {
var ErrAllCheckTriesFailed = errors.New("all check tries failed") var ErrAllCheckTriesFailed = errors.New("all check tries failed")
func withRetries(ctx context.Context, tryTimeouts []time.Duration, func withRetries(ctx context.Context, maxTries uint, tryTimeout, extraTryTime time.Duration,
logger Logger, checkName string, check func(ctx context.Context, try int) error, logger Logger, checkName string, check func(ctx context.Context) error,
) error { ) error {
maxTries := len(tryTimeouts) try := uint(0)
type errData struct { var errs []error
err error for {
durationMS int64 timeout := tryTimeout + time.Duration(try)*extraTryTime //nolint:gosec
}
errs := make([]errData, maxTries)
for i, timeout := range tryTimeouts {
start := time.Now()
checkCtx, cancel := context.WithTimeout(ctx, timeout) checkCtx, cancel := context.WithTimeout(ctx, timeout)
err := check(checkCtx, i) err := check(checkCtx)
cancel() cancel()
switch { switch {
case err == nil: case err == nil:
@@ -252,75 +234,17 @@ func withRetries(ctx context.Context, tryTimeouts []time.Duration,
case ctx.Err() != nil: case ctx.Err() != nil:
return fmt.Errorf("%s: %w", checkName, ctx.Err()) return fmt.Errorf("%s: %w", checkName, ctx.Err())
} }
logger.Debugf("%s attempt %d/%d failed: %s", checkName, i+1, maxTries, err) logger.Debugf("%s attempt %d/%d failed: %s", checkName, try+1, maxTries, err)
errs[i].err = err
errs[i].durationMS = time.Since(start).Round(time.Millisecond).Milliseconds()
}
errStrings := make([]string, len(errs))
for i, err := range errs {
errStrings[i] = fmt.Sprintf("attempt %d (%dms): %s", i+1, err.durationMS, err.err)
}
return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", "))
}
func (c *Checker) startupCheck(ctx context.Context) error {
// connection isn't under load yet when the checker starts, so a short
// 6 seconds timeout suffices and provides quick enough feedback that
// the new connection is not working. However, since the addresses to dial
// may be multiple, we run the check in parallel. If any succeeds, the check passes.
// This is to prevent false negatives at startup, if one of the addresses is down
// for external reasons.
const timeout = 6 * time.Second
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
errCh := make(chan error)
for _, address := range c.tlsDialAddrs {
go func(addr string) {
err := tcpTLSCheck(ctx, c.dialer, addr)
errCh <- err
}(address)
}
errs := make([]error, 0, len(c.tlsDialAddrs))
success := false
for range c.tlsDialAddrs {
err := <-errCh
if err == nil {
success = true
cancel()
continue
} else if success {
continue // ignore canceled errors after success
}
c.logger.Debugf("startup check parallel attempt failed: %s", err)
errs = append(errs, err) errs = append(errs, err)
} try++
if success { if try < maxTries {
return nil continue
} }
errStrings := make([]string, len(errs))
errStrings := make([]string, len(errs)) for i, err := range errs {
for i, err := range errs { errStrings[i] = fmt.Sprintf("attempt %d: %s", i+1, err.Error())
errStrings[i] = fmt.Sprintf("parallel attempt %d/%d failed: %s", i+1, len(errs), err) }
} return fmt.Errorf("%w: after %d %s attempts (%s)",
return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", ")) ErrAllCheckTriesFailed, maxTries, checkName, strings.Join(errStrings, "; "))
}
const (
smallCheckDNS = "dns"
smallCheckICMP = "icmp"
)
func smallCheckTypeToString(smallCheckType string) string {
switch smallCheckType {
case smallCheckICMP:
return "ICMP echo"
case smallCheckDNS:
return "plain DNS over UDP"
default:
panic("unknown small check type: " + smallCheckType)
} }
} }

View File

@@ -18,11 +18,11 @@ func Test_Checker_fullcheck(t *testing.T) {
t.Parallel() t.Parallel()
dialer := &net.Dialer{} dialer := &net.Dialer{}
addresses := []string{"badaddress:9876", "cloudflare.com:443", "google.com:443"} const address = "cloudflare.com:443"
checker := &Checker{ checker := &Checker{
dialer: dialer, dialer: dialer,
tlsDialAddrs: addresses, tlsDialAddr: address,
} }
canceledCtx, cancel := context.WithCancel(context.Background()) canceledCtx, cancel := context.WithCancel(context.Background())
@@ -52,8 +52,8 @@ func Test_Checker_fullcheck(t *testing.T) {
dialer := &net.Dialer{} dialer := &net.Dialer{}
checker := &Checker{ checker := &Checker{
dialer: dialer, dialer: dialer,
tlsDialAddrs: []string{listeningAddress.String()}, tlsDialAddr: listeningAddress.String(),
} }
err = checker.fullPeriodicCheck(ctx) err = checker.fullPeriodicCheck(ctx)

View File

@@ -56,10 +56,10 @@ func (c *Client) Check(ctx context.Context) error {
switch { switch {
case err != nil: case err != nil:
c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs) c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs)
return fmt.Errorf("with DNS server %s: %w", dnsAddr, err) return err
case len(ips) == 0: case len(ips) == 0:
c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs) c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs)
return fmt.Errorf("with DNS server %s: %w", dnsAddr, ErrLookupNoIPs) return fmt.Errorf("%w", ErrLookupNoIPs)
default: default:
return nil return nil
} }

View File

@@ -82,20 +82,20 @@ func (i *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") { if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrNotPermitted) err = fmt.Errorf("%w", ErrNotPermitted)
} }
return fmt.Errorf("writing ICMP message to %s: %w", ip, err) return fmt.Errorf("writing ICMP message: %w", err)
} }
receivedData, err := receiveEchoReply(conn, id, i.buffer, ipVersion, i.logger) receivedData, err := receiveEchoReply(conn, id, i.buffer, ipVersion, i.logger)
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) && ctx.Err() != nil { if errors.Is(err, net.ErrClosed) && ctx.Err() != nil {
return fmt.Errorf("%w from %s", ErrTimedOut, ip) return fmt.Errorf("%w", ErrTimedOut)
} }
return fmt.Errorf("receiving ICMP echo reply from %s: %w", ip, err) return fmt.Errorf("receiving ICMP echo reply: %w", err)
} }
sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert
if !bytes.Equal(receivedData, sentData) { if !bytes.Equal(receivedData, sentData) {
return fmt.Errorf("%w: sent %x to %s and received %x", ErrICMPEchoDataMismatch, sentData, ip, receivedData) return fmt.Errorf("%w: sent %x and received %x", ErrICMPEchoDataMismatch, sentData, receivedData)
} }
return nil return nil

View File

@@ -10,7 +10,7 @@ import (
) )
func runCommand(ctx context.Context, cmder Cmder, logger Logger, func runCommand(ctx context.Context, cmder Cmder, logger Logger,
commandTemplate string, ports []uint16, vpnInterface string, commandTemplate string, ports []uint16,
) (err error) { ) (err error) {
portStrings := make([]string, len(ports)) portStrings := make([]string, len(ports))
for i, port := range ports { for i, port := range ports {
@@ -18,8 +18,6 @@ func runCommand(ctx context.Context, cmder Cmder, logger Logger,
} }
portsString := strings.Join(portStrings, ",") portsString := strings.Join(portStrings, ",")
commandString := strings.ReplaceAll(commandTemplate, "{{PORTS}}", portsString) commandString := strings.ReplaceAll(commandTemplate, "{{PORTS}}", portsString)
commandString = strings.ReplaceAll(commandString, "{{PORT}}", portStrings[0])
commandString = strings.ReplaceAll(commandString, "{{VPN_INTERFACE}}", vpnInterface)
args, err := command.Split(commandString) args, err := command.Split(commandString)
if err != nil { if err != nil {
return fmt.Errorf("parsing command: %w", err) return fmt.Errorf("parsing command: %w", err)

View File

@@ -17,13 +17,12 @@ func Test_Service_runCommand(t *testing.T) {
ctx := context.Background() ctx := context.Background()
cmder := command.New() cmder := command.New()
const commandTemplate = `/bin/sh -c "echo {{PORTS}}-{{PORT}}-{{VPN_INTERFACE}}"` const commandTemplate = `/bin/sh -c "echo {{PORTS}}"`
ports := []uint16{1234, 5678} ports := []uint16{1234, 5678}
const vpnInterface = "tun0"
logger := NewMockLogger(ctrl) logger := NewMockLogger(ctrl)
logger.EXPECT().Info("1234,5678-1234-tun0") logger.EXPECT().Info("1234,5678")
err := runCommand(ctx, cmder, logger, commandTemplate, ports, vpnInterface) err := runCommand(ctx, cmder, logger, commandTemplate, ports)
require.NoError(t, err) require.NoError(t, err)
} }

View File

@@ -74,7 +74,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
s.portMutex.Unlock() s.portMutex.Unlock()
if s.settings.UpCommand != "" { if s.settings.UpCommand != "" {
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports, s.settings.Interface) err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
if err != nil { if err != nil {
err = fmt.Errorf("running up command: %w", err) err = fmt.Errorf("running up command: %w", err)
s.logger.Error(err.Error()) s.logger.Error(err.Error())

View File

@@ -34,7 +34,7 @@ func (s *Service) cleanup() (err error) {
const downTimeout = 60 * time.Second const downTimeout = 60 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), downTimeout) ctx, cancel := context.WithTimeout(context.Background(), downTimeout)
defer cancel() defer cancel()
err = runCommand(ctx, s.cmder, s.logger, s.settings.DownCommand, s.ports, s.settings.Interface) err = runCommand(ctx, s.cmder, s.logger, s.settings.DownCommand, s.ports)
if err != nil { if err != nil {
err = fmt.Errorf("running down command: %w", err) err = fmt.Errorf("running down command: %w", err)
s.logger.Error(err.Error()) s.logger.Error(err.Error())

View File

@@ -18,12 +18,12 @@ type Provider struct {
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
client *http.Client, updaterWarner common.Warner, client *http.Client, updaterWarner common.Warner,
email, password string, username, password string,
) *Provider { ) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
Fetcher: updater.New(client, updaterWarner, email, password), Fetcher: updater.New(client, updaterWarner, username, password),
} }
} }

View File

@@ -76,7 +76,7 @@ func (c *apiClient) setHeaders(request *http.Request, cookie cookie) {
// authenticate performs the full Proton authentication flow // authenticate performs the full Proton authentication flow
// to obtain an authenticated cookie (uid, token and session ID). // to obtain an authenticated cookie (uid, token and session ID).
func (c *apiClient) authenticate(ctx context.Context, email, password string, func (c *apiClient) authenticate(ctx context.Context, username, password string,
) (authCookie cookie, err error) { ) (authCookie cookie, err error) {
sessionID, err := c.getSessionID(ctx) sessionID, err := c.getSessionID(ctx)
if err != nil { if err != nil {
@@ -98,8 +98,8 @@ func (c *apiClient) authenticate(ctx context.Context, email, password string,
token: cookieToken, token: cookieToken,
sessionID: sessionID, sessionID: sessionID,
} }
username, modulusPGPClearSigned, serverEphemeralBase64, saltBase64, modulusPGPClearSigned, serverEphemeralBase64, saltBase64,
srpSessionHex, version, err := c.authInfo(ctx, email, unauthCookie) srpSessionHex, version, err := c.authInfo(ctx, username, unauthCookie)
if err != nil { if err != nil {
return cookie{}, fmt.Errorf("getting auth information: %w", err) return cookie{}, fmt.Errorf("getting auth information: %w", err)
} }
@@ -118,7 +118,7 @@ func (c *apiClient) authenticate(ctx context.Context, email, password string,
return cookie{}, fmt.Errorf("generating SRP proofs: %w", err) return cookie{}, fmt.Errorf("generating SRP proofs: %w", err)
} }
authCookie, err = c.auth(ctx, unauthCookie, email, srpSessionHex, proofs) authCookie, err = c.auth(ctx, unauthCookie, username, srpSessionHex, proofs)
if err != nil { if err != nil {
return cookie{}, fmt.Errorf("authentifying: %w", err) return cookie{}, fmt.Errorf("authentifying: %w", err)
} }
@@ -299,45 +299,48 @@ func (c *apiClient) cookieToken(ctx context.Context, sessionID, tokenType, acces
return "", fmt.Errorf("%w", ErrAuthCookieNotFound) return "", fmt.Errorf("%w", ErrAuthCookieNotFound)
} }
var ErrUsernameDoesNotExist = errors.New("username does not exist") var (
ErrUsernameDoesNotExist = errors.New("username does not exist")
ErrUsernameMismatch = errors.New("username in response does not match request username")
)
// authInfo fetches SRP parameters for the account. // authInfo fetches SRP parameters for the account.
func (c *apiClient) authInfo(ctx context.Context, email string, unauthCookie cookie) ( func (c *apiClient) authInfo(ctx context.Context, username string, unauthCookie cookie) (
username, modulusPGPClearSigned, serverEphemeralBase64, saltBase64, srpSessionHex string, modulusPGPClearSigned, serverEphemeralBase64, saltBase64, srpSessionHex string,
version int, err error, version int, err error,
) { ) {
type requestBodySchema struct { type requestBodySchema struct {
Intent string `json:"Intent"` // "Proton" Intent string `json:"Intent"` // "Proton"
Username string `json:"Username"` Username string `json:"Username"` // username without @domain.com
} }
requestBody := requestBodySchema{ requestBody := requestBodySchema{
Intent: "Proton", Intent: "Proton",
Username: email, Username: username,
} }
buffer := bytes.NewBuffer(nil) buffer := bytes.NewBuffer(nil)
encoder := json.NewEncoder(buffer) encoder := json.NewEncoder(buffer)
if err := encoder.Encode(requestBody); err != nil { if err := encoder.Encode(requestBody); err != nil {
return "", "", "", "", "", 0, fmt.Errorf("encoding request body: %w", err) return "", "", "", "", 0, fmt.Errorf("encoding request body: %w", err)
} }
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURLBase+"/core/v4/auth/info", buffer) request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURLBase+"/core/v4/auth/info", buffer)
if err != nil { if err != nil {
return "", "", "", "", "", 0, fmt.Errorf("creating request: %w", err) return "", "", "", "", 0, fmt.Errorf("creating request: %w", err)
} }
c.setHeaders(request, unauthCookie) c.setHeaders(request, unauthCookie)
response, err := c.httpClient.Do(request) response, err := c.httpClient.Do(request)
if err != nil { if err != nil {
return "", "", "", "", "", 0, err return "", "", "", "", 0, err
} }
defer response.Body.Close() defer response.Body.Close()
responseBody, err := io.ReadAll(response.Body) responseBody, err := io.ReadAll(response.Body)
if err != nil { if err != nil {
return "", "", "", "", "", 0, fmt.Errorf("reading response body: %w", err) return "", "", "", "", 0, fmt.Errorf("reading response body: %w", err)
} else if response.StatusCode != http.StatusOK { } else if response.StatusCode != http.StatusOK {
return "", "", "", "", "", 0, buildError(response.StatusCode, responseBody) return "", "", "", "", 0, buildError(response.StatusCode, responseBody)
} }
var info struct { var info struct {
@@ -351,30 +354,32 @@ func (c *apiClient) authInfo(ctx context.Context, email string, unauthCookie coo
} }
err = json.Unmarshal(responseBody, &info) err = json.Unmarshal(responseBody, &info)
if err != nil { if err != nil {
return "", "", "", "", "", 0, fmt.Errorf("decoding response body: %w", err) return "", "", "", "", 0, fmt.Errorf("decoding response body: %w", err)
} }
const successCode = 1000 const successCode = 1000
switch { switch {
case info.Code != successCode: case info.Code != successCode:
return "", "", "", "", "", 0, fmt.Errorf("%w: expected %d got %d", return "", "", "", "", 0, fmt.Errorf("%w: expected %d got %d",
ErrCodeNotSuccess, successCode, info.Code) ErrCodeNotSuccess, successCode, info.Code)
case info.Modulus == "": case info.Modulus == "":
return "", "", "", "", "", 0, fmt.Errorf("%w: modulus is empty", ErrDataFieldMissing) return "", "", "", "", 0, fmt.Errorf("%w: modulus is empty", ErrDataFieldMissing)
case info.ServerEphemeral == "": case info.ServerEphemeral == "":
return "", "", "", "", "", 0, fmt.Errorf("%w: server ephemeral is empty", ErrDataFieldMissing) return "", "", "", "", 0, fmt.Errorf("%w: server ephemeral is empty", ErrDataFieldMissing)
case info.Salt == "": case info.Salt == "":
return "", "", "", "", "", 0, fmt.Errorf("%w (salt data field is empty)", ErrUsernameDoesNotExist) return "", "", "", "", 0, fmt.Errorf("%w (salt data field is empty)", ErrUsernameDoesNotExist)
case info.SRPSession == "": case info.SRPSession == "":
return "", "", "", "", "", 0, fmt.Errorf("%w: SRP session is empty", ErrDataFieldMissing) return "", "", "", "", 0, fmt.Errorf("%w: SRP session is empty", ErrDataFieldMissing)
case info.Username == "":
return "", "", "", "", "", 0, fmt.Errorf("%w: username is empty", ErrDataFieldMissing) case info.Username != username:
return "", "", "", "", 0, fmt.Errorf("%w: expected %s got %s",
ErrUsernameMismatch, username, info.Username)
case info.Version == nil: case info.Version == nil:
return "", "", "", "", "", 0, fmt.Errorf("%w: version is missing", ErrDataFieldMissing) return "", "", "", "", 0, fmt.Errorf("%w: version is missing", ErrDataFieldMissing)
} }
version = int(*info.Version) //nolint:gosec version = int(*info.Version) //nolint:gosec
return info.Username, info.Modulus, info.ServerEphemeral, info.Salt, return info.Modulus, info.ServerEphemeral, info.Salt,
info.SRPSession, version, nil info.SRPSession, version, nil
} }

View File

@@ -14,8 +14,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
servers []models.Server, err error, servers []models.Server, err error,
) { ) {
switch { switch {
case u.email == "": case u.username == "":
return nil, fmt.Errorf("%w: email is empty", common.ErrCredentialsMissing) return nil, fmt.Errorf("%w: username is empty", common.ErrCredentialsMissing)
case u.password == "": case u.password == "":
return nil, fmt.Errorf("%w: password is empty", common.ErrCredentialsMissing) return nil, fmt.Errorf("%w: password is empty", common.ErrCredentialsMissing)
} }
@@ -25,7 +25,7 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return nil, fmt.Errorf("creating API client: %w", err) return nil, fmt.Errorf("creating API client: %w", err)
} }
cookie, err := apiClient.authenticate(ctx, u.email, u.password) cookie, err := apiClient.authenticate(ctx, u.username, u.password)
if err != nil { if err != nil {
return nil, fmt.Errorf("authentifying with Proton: %w", err) return nil, fmt.Errorf("authentifying with Proton: %w", err)
} }

View File

@@ -8,15 +8,15 @@ import (
type Updater struct { type Updater struct {
client *http.Client client *http.Client
email string username string
password string password string
warner common.Warner warner common.Warner
} }
func New(client *http.Client, warner common.Warner, email, password string) *Updater { func New(client *http.Client, warner common.Warner, username, password string) *Updater {
return &Updater{ return &Updater{
client: client, client: client,
email: email, username: username,
password: password, password: password,
warner: warner, warner: warner,
} }

View File

@@ -75,7 +75,7 @@ func NewProviders(storage Storage, timeNow func() time.Time,
providers.Privado: privado.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver), providers.Privado: privado.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver),
providers.PrivateInternetAccess: privateinternetaccess.New(storage, randSource, timeNow, client), providers.PrivateInternetAccess: privateinternetaccess.New(storage, randSource, timeNow, client),
providers.Privatevpn: privatevpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver), providers.Privatevpn: privatevpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver),
providers.Protonvpn: protonvpn.New(storage, randSource, client, updaterWarner, *credentials.ProtonEmail, *credentials.ProtonPassword), providers.Protonvpn: protonvpn.New(storage, randSource, client, updaterWarner, *credentials.ProtonUsername, *credentials.ProtonPassword),
providers.Purevpn: purevpn.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver), providers.Purevpn: purevpn.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver),
providers.SlickVPN: slickvpn.New(storage, randSource, client, updaterWarner, parallelResolver), providers.SlickVPN: slickvpn.New(storage, randSource, client, updaterWarner, parallelResolver),
providers.Surfshark: surfshark.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver), providers.Surfshark: surfshark.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver),

View File

@@ -26,25 +26,16 @@ func pickConnection(connections []models.Connection,
return connection, ErrNoConnectionToPickFrom return connection, ErrNoConnectionToPickFrom
} }
var targetIP netip.Addr targetIPSet := selection.TargetIP.IsValid() && !selection.TargetIP.IsUnspecified()
switch selection.VPN {
case vpn.OpenVPN:
targetIP = selection.OpenVPN.EndpointIP
case vpn.Wireguard:
targetIP = selection.Wireguard.EndpointIP
default:
panic("unknown VPN type: " + selection.VPN)
}
targetIPSet := targetIP.IsValid() && !targetIP.IsUnspecified()
if targetIPSet && selection.VPN == vpn.Wireguard { if targetIPSet && selection.VPN == vpn.Wireguard {
// we need the right public key // we need the right public key
return getTargetIPConnection(connections, targetIP) return getTargetIPConnection(connections, selection.TargetIP)
} }
connection = pickRandomConnection(connections, randSource) connection = pickRandomConnection(connections, randSource)
if targetIPSet { if targetIPSet {
connection.IP = targetIP connection.IP = selection.TargetIP
} }
return connection, nil return connection, nil

View File

@@ -18,37 +18,15 @@ func New(settings Settings, debugLogger DebugLogger) (
return &authHandler{ return &authHandler{
childHandler: handler, childHandler: handler,
routeToRoles: routeToRoles, routeToRoles: routeToRoles,
unprotectedRoutes: map[string]struct{}{ logger: debugLogger,
http.MethodGet + " /openvpn/actions/restart": {},
http.MethodGet + " /openvpn/portforwarded": {},
http.MethodGet + " /unbound/actions/restart": {},
http.MethodGet + " /updater/restart": {},
http.MethodGet + " /v1/version": {},
http.MethodGet + " /v1/vpn/status": {},
http.MethodPut + " /v1/vpn/status": {},
// GET /v1/vpn/settings is protected by default
// PUT /v1/vpn/settings is protected by default
http.MethodGet + " /v1/openvpn/status": {},
http.MethodPut + " /v1/openvpn/status": {},
http.MethodGet + " /v1/openvpn/portforwarded": {},
// GET /v1/openvpn/settings is protected by default
http.MethodGet + " /v1/dns/status": {},
http.MethodPut + " /v1/dns/status": {},
http.MethodGet + " /v1/updater/status": {},
http.MethodPut + " /v1/updater/status": {},
http.MethodGet + " /v1/publicip/ip": {},
http.MethodGet + " /v1/portforward": {},
},
logger: debugLogger,
} }
}, nil }, nil
} }
type authHandler struct { type authHandler struct {
childHandler http.Handler childHandler http.Handler
routeToRoles map[string][]internalRole routeToRoles map[string][]internalRole
unprotectedRoutes map[string]struct{} // TODO v3.41.0 remove logger DebugLogger
logger DebugLogger
} }
func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
@@ -66,8 +44,6 @@ func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Reques
continue continue
} }
h.warnIfUnprotectedByDefault(role, route) // TODO v3.41.0 remove
h.logger.Debugf("access to route %s authorized for role %s", route, role.name) h.logger.Debugf("access to route %s authorized for role %s", route, role.name)
h.childHandler.ServeHTTP(writer, request) h.childHandler.ServeHTTP(writer, request)
return return
@@ -88,26 +64,3 @@ func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Reques
route, andStrings(allRoleNames)) route, andStrings(allRoleNames))
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
} }
func (h *authHandler) warnIfUnprotectedByDefault(role internalRole, route string) {
// TODO v3.41.0 remove
if role.name != "public" {
// custom role name, allow none authentication to be specified
return
}
_, isNoneChecker := role.checker.(*noneMethod)
if !isNoneChecker {
// not the none authentication method
return
}
_, isUnprotectedByDefault := h.unprotectedRoutes[route]
if !isUnprotectedByDefault {
// route is not unprotected by default, so this is a user decision
return
}
h.logger.Warnf("route %s is unprotected by default, "+
"please set up authentication following the documentation at "+
"https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/control-server.md#authentication "+
"since this will become no longer publicly accessible after release v3.40.",
route)
}

View File

@@ -40,27 +40,6 @@ func Test_authHandler_ServeHTTP(t *testing.T) {
statusCode: http.StatusUnauthorized, statusCode: http.StatusUnauthorized,
responseBody: "Unauthorized\n", responseBody: "Unauthorized\n",
}, },
"authorized_unprotected_by_default": {
settings: Settings{
Roles: []Role{
{Name: "public", Auth: AuthNone, Routes: []string{"GET /v1/vpn/status"}},
},
},
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
logger := NewMockDebugLogger(ctrl)
logger.EXPECT().Warnf("route %s is unprotected by default, "+
"please set up authentication following the documentation at "+
"https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/control-server.md#authentication "+
"since this will become no longer publicly accessible after release v3.40.",
"GET /v1/vpn/status")
logger.EXPECT().Debugf("access to route %s authorized for role %s",
"GET /v1/vpn/status", "public")
return logger
},
requestMethod: http.MethodGet,
requestPath: "/v1/vpn/status",
statusCode: http.StatusOK,
},
"authorized_none": { "authorized_none": {
settings: Settings{ settings: Settings{
Roles: []Role{ Roles: []Role{

View File

@@ -63,31 +63,6 @@ func (s *Settings) SetDefaultRole(jsonRole string) error {
return nil return nil
} }
func (s *Settings) SetDefaults() {
s.Roles = gosettings.DefaultSlice(s.Roles, []Role{{ // TODO v3.41.0 leave empty
Name: "public",
Auth: "none",
Routes: []string{
http.MethodGet + " /openvpn/actions/restart",
http.MethodGet + " /unbound/actions/restart",
http.MethodGet + " /openvpn/portforwarded",
http.MethodGet + " /updater/restart",
http.MethodGet + " /v1/version",
http.MethodGet + " /v1/vpn/status",
http.MethodPut + " /v1/vpn/status",
http.MethodGet + " /v1/openvpn/status",
http.MethodPut + " /v1/openvpn/status",
http.MethodGet + " /v1/openvpn/portforwarded",
http.MethodGet + " /v1/dns/status",
http.MethodPut + " /v1/dns/status",
http.MethodGet + " /v1/updater/status",
http.MethodPut + " /v1/updater/status",
http.MethodGet + " /v1/publicip/ip",
http.MethodGet + " /v1/portforward",
},
}})
}
func (s Settings) Validate() (err error) { func (s Settings) Validate() (err error) {
for i, role := range s.Roles { for i, role := range s.Roles {
err = role.Validate() err = role.Validate()

View File

@@ -60,7 +60,6 @@ func setupAuthMiddleware(authPath, jsonDefaultRole string, logger Logger) (
if err != nil { if err != nil {
return auth.Settings{}, fmt.Errorf("setting default role: %w", err) return auth.Settings{}, fmt.Errorf("setting default role: %w", err)
} }
authSettings.SetDefaults()
err = authSettings.Validate() err = authSettings.Validate()
if err != nil { if err != nil {
return auth.Settings{}, fmt.Errorf("validating auth settings: %w", err) return auth.Settings{}, fmt.Errorf("validating auth settings: %w", err)

View File

@@ -21,9 +21,6 @@ func (s *Storage) FlushToFile(path string) error {
// flushToFile flushes the merged servers data to the file // flushToFile flushes the merged servers data to the file
// specified by path, as indented JSON. It is not thread-safe. // specified by path, as indented JSON. It is not thread-safe.
func (s *Storage) flushToFile(path string) error { func (s *Storage) flushToFile(path string) error {
if path == "" {
return nil // no file to write to
}
const permission = 0o644 const permission = 0o644
dirPath := filepath.Dir(path) dirPath := filepath.Dir(path)
if err := os.MkdirAll(dirPath, permission); err != nil { if err := os.MkdirAll(dirPath, permission); err != nil {

View File

@@ -8,7 +8,6 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn"
) )
func commaJoin(slice []string) string { func commaJoin(slice []string) string {
@@ -149,13 +148,9 @@ func noServerFoundError(selection settings.ServerSelection) (err error) {
messageParts = append(messageParts, "tor only") messageParts = append(messageParts, "tor only")
} }
targetIP := selection.OpenVPN.EndpointIP if selection.TargetIP.IsValid() {
if selection.VPN == vpn.Wireguard {
targetIP = selection.Wireguard.EndpointIP
}
if targetIP.IsValid() {
messageParts = append(messageParts, messageParts = append(messageParts,
"target ip address "+targetIP.String()) "target ip address "+selection.TargetIP.String())
} }
message := "for " + strings.Join(messageParts, "; ") message := "for " + strings.Join(messageParts, "; ")

View File

@@ -1,3 +1,3 @@
package storage package storage
//go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . Logger //go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . Infoer

View File

@@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/storage (interfaces: Logger) // Source: github.com/qdm12/gluetun/internal/storage (interfaces: Infoer)
// Package storage is a generated GoMock package. // Package storage is a generated GoMock package.
package storage package storage
@@ -10,49 +10,37 @@ import (
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
// MockLogger is a mock of Logger interface. // MockInfoer is a mock of Infoer interface.
type MockLogger struct { type MockInfoer struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockLoggerMockRecorder recorder *MockInfoerMockRecorder
} }
// MockLoggerMockRecorder is the mock recorder for MockLogger. // MockInfoerMockRecorder is the mock recorder for MockInfoer.
type MockLoggerMockRecorder struct { type MockInfoerMockRecorder struct {
mock *MockLogger mock *MockInfoer
} }
// NewMockLogger creates a new mock instance. // NewMockInfoer creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger { func NewMockInfoer(ctrl *gomock.Controller) *MockInfoer {
mock := &MockLogger{ctrl: ctrl} mock := &MockInfoer{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock} mock.recorder = &MockInfoerMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { func (m *MockInfoer) EXPECT() *MockInfoerMockRecorder {
return m.recorder return m.recorder
} }
// Info mocks base method. // Info mocks base method.
func (m *MockLogger) Info(arg0 string) { func (m *MockInfoer) Info(arg0 string) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Info", arg0) m.ctrl.Call(m, "Info", arg0)
} }
// Info indicates an expected call of Info. // Info indicates an expected call of Info.
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call { func (mr *MockInfoerMockRecorder) Info(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockInfoer)(nil).Info), arg0)
}
// Warn mocks base method.
func (m *MockLogger) Warn(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Warn", arg0)
}
// Warn indicates an expected call of Warn.
func (mr *MockLoggerMockRecorder) Warn(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warn", reflect.TypeOf((*MockLogger)(nil).Warn), arg0)
} }

View File

@@ -95,7 +95,7 @@ func Test_extractServersFromBytes(t *testing.T) {
t.Parallel() t.Parallel()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
logger := NewMockLogger(ctrl) logger := NewMockInfoer(ctrl)
var previousLogCall *gomock.Call var previousLogCall *gomock.Call
for _, logged := range testCase.logged { for _, logged := range testCase.logged {
call := logger.EXPECT().Info(logged) call := logger.EXPECT().Info(logged)

File diff suppressed because it is too large Load Diff

View File

@@ -13,35 +13,30 @@ type Storage struct {
// the embedded JSON file on every call to the // the embedded JSON file on every call to the
// SyncServers method. // SyncServers method.
hardcodedServers models.AllServers hardcodedServers models.AllServers
logger Logger logger Infoer
filepath string filepath string
} }
type Logger interface { type Infoer interface {
Info(s string) Info(s string)
Warn(s string)
} }
// New creates a new storage and reads the servers from the // New creates a new storage and reads the servers from the
// embedded servers file and the file on disk. // embedded servers file and the file on disk.
// Passing an empty filepath disables the reading and writing of // Passing an empty filepath disables writing servers to a file.
// servers. func New(logger Infoer, filepath string) (storage *Storage, err error) {
func New(logger Logger, filepath string) (storage *Storage, err error) {
// A unit test prevents any error from being returned // A unit test prevents any error from being returned
// and ensures all providers are part of the servers returned. // and ensures all providers are part of the servers returned.
hardcodedServers, _ := parseHardcodedServers() hardcodedServers, _ := parseHardcodedServers()
storage = &Storage{ storage = &Storage{
hardcodedServers: hardcodedServers, hardcodedServers: hardcodedServers,
mergedServers: hardcodedServers,
logger: logger, logger: logger,
filepath: filepath, filepath: filepath,
} }
if filepath != "" { if err := storage.syncServers(); err != nil {
if err := storage.syncServers(); err != nil { return nil, err
return nil, err
}
} }
return storage, nil return storage, nil

View File

@@ -46,13 +46,13 @@ func (s *Storage) syncServers() (err error) {
} }
// Eventually write file // Eventually write file
if reflect.DeepEqual(serversOnFile, s.mergedServers) { if s.filepath == "" || reflect.DeepEqual(serversOnFile, s.mergedServers) {
return nil return nil
} }
err = s.flushToFile(s.filepath) err = s.flushToFile(s.filepath)
if err != nil { if err != nil {
s.logger.Warn("failed writing servers to file: " + err.Error()) return fmt.Errorf("writing servers to file: %w", err)
} }
return nil return nil
} }

View File

@@ -101,7 +101,7 @@ type CmdStarter interface {
} }
type HealthChecker interface { type HealthChecker interface {
SetConfig(tlsDialAddrs []string, icmpTargetIPs []netip.Addr, smallCheckType string) SetConfig(tlsDialAddr string, icmpTarget netip.Addr)
Start(ctx context.Context) (runError <-chan error, err error) Start(ctx context.Context) (runError <-chan error, err error)
Stop() error Stop() error
} }

View File

@@ -31,24 +31,19 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
} }
} }
icmpTargetIPs := l.healthSettings.ICMPTargetIPs icmpTarget := l.healthSettings.ICMPTargetIP
if len(icmpTargetIPs) == 1 && icmpTargetIPs[0].IsUnspecified() { if icmpTarget.IsUnspecified() {
icmpTargetIPs = []netip.Addr{data.serverIP} icmpTarget = data.serverIP
} }
l.healthChecker.SetConfig(l.healthSettings.TargetAddresses, icmpTargetIPs, l.healthChecker.SetConfig(l.healthSettings.TargetAddress, icmpTarget)
l.healthSettings.SmallCheckType)
healthErrCh, err := l.healthChecker.Start(ctx) healthErrCh, err := l.healthChecker.Start(ctx)
l.healthServer.SetError(err) l.healthServer.SetError(err)
if err != nil { if err != nil {
if *l.healthSettings.RestartVPN { // Note this restart call must be done in a separate goroutine
// Note this restart call must be done in a separate goroutine // from the VPN loop goroutine.
// from the VPN loop goroutine. l.restartVPN(loopCtx, err)
l.restartVPN(loopCtx, err) return
return
}
l.logger.Warnf("(ignored) healthchecker start failed: %s", err)
l.logger.Info("👉 See https://github.com/qdm12/gluetun-wiki/blob/main/faq/healthcheck.md")
} }
if *l.dnsLooper.GetSettings().ServerEnabled { if *l.dnsLooper.GetSettings().ServerEnabled {
@@ -100,7 +95,7 @@ func (l *Loop) collectHealthErrors(ctx, loopCtx context.Context, healthErrCh <-c
l.restartVPN(loopCtx, healthErr) l.restartVPN(loopCtx, healthErr)
return return
} }
l.logger.Warnf("(ignored) healthcheck failed: %s", healthErr) l.logger.Warnf("healthcheck failed: %s", healthErr)
l.logger.Info("👉 See https://github.com/qdm12/gluetun-wiki/blob/main/faq/healthcheck.md") l.logger.Info("👉 See https://github.com/qdm12/gluetun-wiki/blob/main/faq/healthcheck.md")
} else if previousHealthErr != nil { } else if previousHealthErr != nil {
l.logger.Info("healthcheck passed successfully after previous failure(s)") l.logger.Info("healthcheck passed successfully after previous failure(s)")