diff --git a/go.mod b/go.mod index d4206427..3423cdad 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/golang/mock v1.6.0 github.com/qdm12/dns v1.11.0 github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6 - github.com/qdm12/gosettings v0.3.0-rc12 + github.com/qdm12/gosettings v0.3.0-rc13 github.com/qdm12/goshutdown v0.3.0 github.com/qdm12/gosplash v0.1.0 github.com/qdm12/gotree v0.2.0 diff --git a/go.sum b/go.sum index 66763cec..8673579c 100644 --- a/go.sum +++ b/go.sum @@ -91,12 +91,8 @@ github.com/qdm12/golibs v0.0.0-20210603202746-e5494e9c2ebb/go.mod h1:15RBzkun0i8 github.com/qdm12/golibs v0.0.0-20210723175634-a75ca7fd74c2/go.mod h1:6aRbg4Z/bTbm9JfxsGXfWKHi7zsOvPfUTK1S5HuAFKg= github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6 h1:bge5AL7cjHJMPz+5IOz5yF01q/l8No6+lIEBieA8gMg= github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6/go.mod h1:6aRbg4Z/bTbm9JfxsGXfWKHi7zsOvPfUTK1S5HuAFKg= -github.com/qdm12/gosettings v0.3.0-rc11 h1:zbH+TiimUdZTzOTMuFqCCC0XFKr3KGC6sZDyuw6y65A= -github.com/qdm12/gosettings v0.3.0-rc11/go.mod h1:+hHzN8lsE63T01t6SruGzc6xkpvfsZFod/ooDs8FWnQ= -github.com/qdm12/gosettings v0.3.0-rc12 h1:HhdVkpFiZfwsAbSiPNpCHk/OsY7Ogl+wrU/y9/1R0y4= -github.com/qdm12/gosettings v0.3.0-rc12/go.mod h1:+hHzN8lsE63T01t6SruGzc6xkpvfsZFod/ooDs8FWnQ= -github.com/qdm12/gosettings v0.3.0-rc9 h1:/Hr+lXjAeZFQ5LiEX+sKgMyWSckmhvTSs9iGo/Ch+q0= -github.com/qdm12/gosettings v0.3.0-rc9/go.mod h1:+hHzN8lsE63T01t6SruGzc6xkpvfsZFod/ooDs8FWnQ= +github.com/qdm12/gosettings v0.3.0-rc13 h1:fag+/hFPBUcNk3a5ifUbwNS2VgXFpxindkl8mQNk76U= +github.com/qdm12/gosettings v0.3.0-rc13/go.mod h1:JRV3opOpHvnKlIA29lKQMdYw1WSMVMfHYLLHPHol5ME= github.com/qdm12/goshutdown v0.3.0 h1:pqBpJkdwlZlfTEx4QHtS8u8CXx6pG0fVo6S1N0MpSEM= github.com/qdm12/goshutdown v0.3.0/go.mod h1:EqZ46No00kCTZ5qzdd3qIzY6ayhMt24QI8Mh8LVQYmM= github.com/qdm12/gosplash v0.1.0 h1:Sfl+zIjFZFP7b0iqf2l5UkmEY97XBnaKkH3FNY6Gf7g= diff --git a/internal/configuration/sources/env/dns.go b/internal/configuration/sources/env/dns.go index 38c3f9ac..c340d7fd 100644 --- a/internal/configuration/sources/env/dns.go +++ b/internal/configuration/sources/env/dns.go @@ -27,19 +27,24 @@ func (s *Source) readDNS() (dns settings.DNS, err error) { } func (s *Source) readDNSServerAddress() (address netip.Addr, err error) { - key, value := s.getEnvWithRetro("DNS_ADDRESS", []string{"DNS_PLAINTEXT_ADDRESS"}) - if value == nil { + const currentKey = "DNS_ADDRESS" + key := firstKeySet(s.env, "DNS_PLAINTEXT_ADDRESS", currentKey) + switch key { + case "": return address, nil + case currentKey: + default: // Retro-compatibility + s.handleDeprecatedKey(key, currentKey) } - address, err = netip.ParseAddr(*value) + address, err = s.env.NetipAddr(key) if err != nil { - return address, fmt.Errorf("environment variable %s: %w", key, err) + return address, err } // TODO remove in v4 if address.Unmap().Compare(netip.AddrFrom4([4]byte{127, 0, 0, 1})) != 0 { - s.warner.Warn(key + " is set to " + *value + + s.warner.Warn(key + " is set to " + address.String() + " so the DNS over TLS (DoT) server will not be used." + " The default value changed to 127.0.0.1 so it uses the internal DoT serves." + " If the DoT server fails to start, the IPv4 address of the first plaintext DNS server" + diff --git a/internal/configuration/sources/env/dnsblacklist.go b/internal/configuration/sources/env/dnsblacklist.go index d125fb79..20f44593 100644 --- a/internal/configuration/sources/env/dnsblacklist.go +++ b/internal/configuration/sources/env/dnsblacklist.go @@ -6,6 +6,7 @@ import ( "net/netip" "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gosettings/sources/env" ) func (s *Source) readDNSBlacklist() (blacklist settings.DNSBlacklist, err error) { @@ -14,7 +15,8 @@ func (s *Source) readDNSBlacklist() (blacklist settings.DNSBlacklist, err error) return blacklist, err } - blacklist.BlockSurveillance, err = s.readBlockSurveillance() + blacklist.BlockSurveillance, err = s.env.BoolPtr("BLOCK_SURVEILLANCE", + env.RetroKeys("BLOCK_NSA")) if err != nil { return blacklist, err } @@ -35,11 +37,6 @@ func (s *Source) readDNSBlacklist() (blacklist settings.DNSBlacklist, err error) return blacklist, nil } -func (s *Source) readBlockSurveillance() (blocked *bool, err error) { - key, _ := s.getEnvWithRetro("BLOCK_SURVEILLANCE", []string{"BLOCK_NSA"}) - return s.env.BoolPtr(key) -} - var ( ErrPrivateAddressNotValid = errors.New("private address is not a valid IP or CIDR range") ) diff --git a/internal/configuration/sources/env/firewall.go b/internal/configuration/sources/env/firewall.go index 0c5caf1b..bcaba944 100644 --- a/internal/configuration/sources/env/firewall.go +++ b/internal/configuration/sources/env/firewall.go @@ -1,32 +1,25 @@ package env import ( - "errors" - "fmt" - "net/netip" - "strconv" - "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gosettings/sources/env" ) func (s *Source) readFirewall() (firewall settings.Firewall, err error) { - vpnInputPortStrings := s.env.CSV("FIREWALL_VPN_INPUT_PORTS") - firewall.VPNInputPorts, err = stringsToPorts(vpnInputPortStrings) + firewall.VPNInputPorts, err = s.env.CSVUint16("FIREWALL_VPN_INPUT_PORTS") if err != nil { - return firewall, fmt.Errorf("environment variable FIREWALL_VPN_INPUT_PORTS: %w", err) + return firewall, err } - inputPortStrings := s.env.CSV("FIREWALL_INPUT_PORTS") - firewall.InputPorts, err = stringsToPorts(inputPortStrings) + firewall.InputPorts, err = s.env.CSVUint16("FIREWALL_INPUT_PORTS") if err != nil { - return firewall, fmt.Errorf("environment variable FIREWALL_INPUT_PORTS: %w", err) + return firewall, err } - outboundSubnetsKey, _ := s.getEnvWithRetro("FIREWALL_OUTBOUND_SUBNETS", []string{"EXTRA_SUBNETS"}) - outboundSubnetStrings := s.env.CSV(outboundSubnetsKey) - firewall.OutboundSubnets, err = stringsToNetipPrefixes(outboundSubnetStrings) + firewall.OutboundSubnets, err = s.env.CSVNetipPrefixes("FIREWALL_OUTBOUND_SUBNETS", + env.RetroKeys("EXTRA_SUBNETS")) if err != nil { - return firewall, fmt.Errorf("environment variable %s: %w", outboundSubnetsKey, err) + return firewall, err } firewall.Enabled, err = s.env.BoolPtr("FIREWALL") @@ -41,40 +34,3 @@ func (s *Source) readFirewall() (firewall settings.Firewall, err error) { return firewall, nil } - -var ( - ErrPortParsing = errors.New("cannot parse port") - ErrPortValue = errors.New("port value is not valid") -) - -func stringsToPorts(ss []string) (ports []uint16, err error) { - if len(ss) == 0 { - return nil, nil - } - ports = make([]uint16, len(ss)) - for i, s := range ss { - port, err := strconv.Atoi(s) - if err != nil { - return nil, fmt.Errorf("%w: %s: %s", ErrPortParsing, s, err) - } else if port < 1 || port > 65535 { - return nil, fmt.Errorf("%w: must be between 1 and 65535: %d", - ErrPortValue, port) - } - ports[i] = uint16(port) - } - return ports, nil -} - -func stringsToNetipPrefixes(ss []string) (ipPrefixes []netip.Prefix, err error) { - if len(ss) == 0 { - return nil, nil - } - ipPrefixes = make([]netip.Prefix, len(ss)) - for i, s := range ss { - ipPrefixes[i], err = netip.ParsePrefix(s) - if err != nil { - return nil, fmt.Errorf("parsing IP network %q: %w", s, err) - } - } - return ipPrefixes, nil -} diff --git a/internal/configuration/sources/env/health.go b/internal/configuration/sources/env/health.go index f3c99237..48033373 100644 --- a/internal/configuration/sources/env/health.go +++ b/internal/configuration/sources/env/health.go @@ -1,15 +1,14 @@ package env import ( - "time" - "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gosettings/sources/env" ) func (s *Source) ReadHealth() (health settings.Health, err error) { health.ServerAddress = s.env.String("HEALTH_SERVER_ADDRESS") - targetAddressEnvKey, _ := s.getEnvWithRetro("HEALTH_TARGET_ADDRESS", []string{"HEALTH_ADDRESS_TO_PING"}) - health.TargetAddress = s.env.String(targetAddressEnvKey) + health.TargetAddress = s.env.String("HEALTH_TARGET_ADDRESS", + env.RetroKeys("HEALTH_ADDRESS_TO_PING")) successWaitPtr, err := s.env.DurationPtr("HEALTH_SUCCESS_WAIT_DURATION") if err != nil { @@ -18,24 +17,19 @@ func (s *Source) ReadHealth() (health settings.Health, err error) { health.SuccessWait = *successWaitPtr } - health.VPN.Initial, err = s.readDurationWithRetro( + health.VPN.Initial, err = s.env.DurationPtr( "HEALTH_VPN_DURATION_INITIAL", - "HEALTH_OPENVPN_DURATION_INITIAL") + env.RetroKeys("HEALTH_OPENVPN_DURATION_INITIAL")) if err != nil { return health, err } - health.VPN.Addition, err = s.readDurationWithRetro( + health.VPN.Addition, err = s.env.DurationPtr( "HEALTH_VPN_DURATION_ADDITION", - "HEALTH_OPENVPN_DURATION_ADDITION") + env.RetroKeys("HEALTH_OPENVPN_DURATION_ADDITION")) if err != nil { return health, err } return health, nil } - -func (s *Source) readDurationWithRetro(envKey, retroEnvKey string) (d *time.Duration, err error) { - envKey, _ = s.getEnvWithRetro(envKey, []string{retroEnvKey}) - return s.env.DurationPtr(envKey) -} diff --git a/internal/configuration/sources/env/helpers.go b/internal/configuration/sources/env/helpers.go index b598c66f..dc877e00 100644 --- a/internal/configuration/sources/env/helpers.go +++ b/internal/configuration/sources/env/helpers.go @@ -3,6 +3,8 @@ package env import ( "fmt" "os" + + "github.com/qdm12/gosettings/sources/env" ) func unsetEnvKeys(envKeys []string, err error) (newErr error) { @@ -19,3 +21,13 @@ func unsetEnvKeys(envKeys []string, err error) (newErr error) { func ptrTo[T any](value T) *T { return &value } + +func firstKeySet(e env.Env, keys ...string) (firstKeySet string) { + for _, key := range keys { + value := e.Get(key) + if value != nil { + return key + } + } + return "" +} diff --git a/internal/configuration/sources/env/httproxy.go b/internal/configuration/sources/env/httproxy.go index 78cd0c66..41d7b27e 100644 --- a/internal/configuration/sources/env/httproxy.go +++ b/internal/configuration/sources/env/httproxy.go @@ -9,15 +9,20 @@ import ( ) func (s *Source) readHTTPProxy() (httpProxy settings.HTTPProxy, err error) { - _, httpProxy.User = s.getEnvWithRetro("HTTPPROXY_USER", - []string{"PROXY_USER", "TINYPROXY_USER"}, env.ForceLowercase(false)) + httpProxy.User = s.env.Get("HTTPPROXY_USER", + env.RetroKeys("PROXY_USER", "TINYPROXY_USER"), + env.ForceLowercase(false)) - _, httpProxy.Password = s.getEnvWithRetro("HTTPPROXY_PASSWORD", - []string{"PROXY_PASSWORD", "TINYPROXY_PASSWORD"}, env.ForceLowercase(false)) + httpProxy.Password = s.env.Get("HTTPPROXY_PASSWORD", + env.RetroKeys("PROXY_PASSWORD", "TINYPROXY_PASSWORD"), + env.ForceLowercase(false)) - httpProxy.ListeningAddress = s.readHTTProxyListeningAddress() + httpProxy.ListeningAddress, err = s.readHTTProxyListeningAddress() + if err != nil { + return httpProxy, err + } - httpProxy.Enabled, err = s.readHTTProxyEnabled() + httpProxy.Enabled, err = s.env.BoolPtr("HTTPPROXY", env.RetroKeys("PROXY", "TINYPROXY")) if err != nil { return httpProxy, err } @@ -35,37 +40,42 @@ func (s *Source) readHTTPProxy() (httpProxy settings.HTTPProxy, err error) { return httpProxy, nil } -func (s *Source) readHTTProxyListeningAddress() (listeningAddress string) { - key, value := s.getEnvWithRetro("HTTPPROXY_LISTENING_ADDRESS", - []string{"PROXY_PORT", "TINYPROXY_PORT", "HTTPPROXY_PORT"}) - if value == nil { - return "" - } else if key == "HTTPPROXY_LISTENING_ADDRESS" { - return *value +func (s *Source) readHTTProxyListeningAddress() (listeningAddress string, err error) { + const currentKey = "HTTPPROXY_LISTENING_ADDRESS" + key := firstKeySet(s.env, "HTTPPROXY_PORT", "TINYPROXY_PORT", "PROXY_PORT", + currentKey) + switch key { + case "": + return "", nil + case currentKey: + return s.env.String(key), nil } - return ":" + *value -} -func (s *Source) readHTTProxyEnabled() (enabled *bool, err error) { - key, _ := s.getEnvWithRetro("HTTPPROXY", - []string{"PROXY", "TINYPROXY"}) - return s.env.BoolPtr(key) + // Retro-compatible keys using a port only + s.handleDeprecatedKey(key, currentKey) + port, err := s.env.Uint16Ptr(key) + if err != nil { + return "", err + } + return fmt.Sprintf(":%d", *port), nil } func (s *Source) readHTTProxyLog() (enabled *bool, err error) { - key, value := s.getEnvWithRetro("HTTPPROXY_LOG", - []string{"PROXY_LOG_LEVEL", "TINYPROXY_LOG"}) - if value == nil { + const currentKey = "HTTPPROXY_LOG" + key := firstKeySet(s.env, "PROXY_LOG", "TINYPROXY_LOG", "HTTPPROXY_LOG") + switch key { + case "": return nil, nil //nolint:nilnil + case currentKey: + return s.env.BoolPtr(key) } - var binaryOptions []binary.Option - if key != "HTTPROXY_LOG" { - retroOption := binary.OptionEnabled("on", "info", "connect", "notice") - binaryOptions = append(binaryOptions, retroOption) - } + // Retro-compatible keys using different boolean verbs + s.handleDeprecatedKey(key, currentKey) + value := s.env.String(key) + retroOption := binary.OptionEnabled("on", "info", "connect", "notice") - enabled, err = binary.Validate(*value, binaryOptions...) + enabled, err = binary.Validate(value, retroOption) if err != nil { return nil, fmt.Errorf("environment variable %s: %w", key, err) } diff --git a/internal/configuration/sources/env/openvpn.go b/internal/configuration/sources/env/openvpn.go index d088296f..abab5ebc 100644 --- a/internal/configuration/sources/env/openvpn.go +++ b/internal/configuration/sources/env/openvpn.go @@ -1,12 +1,10 @@ package env import ( - "fmt" "strings" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gosettings/sources/env" - "github.com/qdm12/govalid/binary" ) func (s *Source) readOpenVPN() ( @@ -17,15 +15,12 @@ func (s *Source) readOpenVPN() ( }() openVPN.Version = s.env.String("OPENVPN_VERSION") - _, openVPN.User = s.getEnvWithRetro("OPENVPN_USER", - []string{"USER"}, env.ForceLowercase(false)) - _, openVPN.Password = s.getEnvWithRetro("OPENVPN_PASSWORD", - []string{"PASSWORD"}, env.ForceLowercase(false)) + openVPN.User = s.env.Get("OPENVPN_USER", + env.RetroKeys("USER"), env.ForceLowercase(false)) + openVPN.Password = s.env.Get("OPENVPN_PASSWORD", + env.RetroKeys("PASSWORD"), env.ForceLowercase(false)) openVPN.ConfFile = s.env.Get("OPENVPN_CUSTOM_CONFIG") - - ciphersKey, _ := s.getEnvWithRetro("OPENVPN_CIPHERS", []string{"OPENVPN_CIPHER"}) - openVPN.Ciphers = s.env.CSV(ciphersKey) - + openVPN.Ciphers = s.env.CSV("OPENVPN_CIPHERS", env.RetroKeys("OPENVPN_CIPHER")) openVPN.Auth = s.env.Get("OPENVPN_AUTH") openVPN.Cert = s.env.Get("OPENVPN_CERT", env.ForceLowercase(false)) openVPN.Key = s.env.Get("OPENVPN_KEY", env.ForceLowercase(false)) @@ -39,11 +34,8 @@ func (s *Source) readOpenVPN() ( return openVPN, err } - _, openvpnInterface := s.getEnvWithRetro("VPN_INTERFACE", - []string{"OPENVPN_INTERFACE"}, env.ForceLowercase(false)) - if openvpnInterface != nil { - openVPN.Interface = *openvpnInterface - } + openVPN.Interface = s.env.String("VPN_INTERFACE", + env.RetroKeys("OPENVPN_INTERFACE"), env.ForceLowercase(false)) openVPN.ProcessUser, err = s.readOpenVPNProcessUser() if err != nil { @@ -64,32 +56,22 @@ func (s *Source) readOpenVPN() ( } func (s *Source) readPIAEncryptionPreset() (presetPtr *string) { - _, presetPtr = s.getEnvWithRetro( + return s.env.Get( "PRIVATE_INTERNET_ACCESS_OPENVPN_ENCRYPTION_PRESET", - []string{"PIA_ENCRYPTION", "ENCRYPTION"}) - return presetPtr + env.RetroKeys("ENCRYPTION", "PIA_ENCRYPTION")) } func (s *Source) readOpenVPNProcessUser() (processUser string, err error) { - key, value := s.getEnvWithRetro("OPENVPN_PROCESS_USER", - []string{"OPENVPN_ROOT"}) - if value == nil { - return "", nil - } else if key == "OPENVPN_PROCESS_USER" { - return *value, nil + value, err := s.env.BoolPtr("OPENVPN_ROOT") // Retro-compatibility + if err != nil { + return "", err + } else if value != nil { + if *value { + return "root", nil + } + const defaultNonRootUser = "nonrootuser" + return defaultNonRootUser, nil } - // Retro-compatibility - if *value == "" { - return "", nil - } - root, err := binary.Validate(*value) - if err != nil { - return "", fmt.Errorf("environment variable %s: %w", key, err) - } - if *root { - return "root", nil - } - const defaultNonRootUser = "nonrootuser" - return defaultNonRootUser, nil + return s.env.String("OPENVPN_PROCESS_USER"), nil } diff --git a/internal/configuration/sources/env/openvpnselection.go b/internal/configuration/sources/env/openvpnselection.go index f9611722..8bca2fef 100644 --- a/internal/configuration/sources/env/openvpnselection.go +++ b/internal/configuration/sources/env/openvpnselection.go @@ -19,7 +19,8 @@ func (s *Source) readOpenVPNSelection() ( return selection, err } - selection.CustomPort, err = s.readOpenVPNCustomPort() + selection.CustomPort, err = s.env.Uint16Ptr("VPN_ENDPOINT_PORT", + env.RetroKeys("PORT", "OPENVPN_PORT")) if err != nil { return selection, err } @@ -32,15 +33,18 @@ func (s *Source) readOpenVPNSelection() ( var ErrOpenVPNProtocolNotValid = errors.New("OpenVPN protocol is not valid") func (s *Source) readOpenVPNProtocol() (tcp *bool, err error) { - envKey, protocolPtr := s.getEnvWithRetro("OPENVPN_PROTOCOL", []string{"PROTOCOL"}) - if protocolPtr == nil { - return nil, nil //nolint:nilnil - } - protocol := *protocolPtr - - switch strings.ToLower(protocol) { + const currentKey = "OPENVPN_PROTOCOL" + envKey := firstKeySet(s.env, "PROTOCOL", currentKey) + switch envKey { case "": return nil, nil //nolint:nilnil + case currentKey: + default: // Retro compatibility + s.handleDeprecatedKey(envKey, currentKey) + } + + protocol := s.env.String(envKey) + switch strings.ToLower(protocol) { case constants.UDP: return ptrTo(false), nil case constants.TCP: @@ -50,8 +54,3 @@ func (s *Source) readOpenVPNProtocol() (tcp *bool, err error) { envKey, ErrOpenVPNProtocolNotValid, protocol) } } - -func (s *Source) readOpenVPNCustomPort() (customPort *uint16, err error) { - key, _ := s.getEnvWithRetro("VPN_ENDPOINT_PORT", []string{"PORT", "OPENVPN_PORT"}) - return s.env.Uint16Ptr(key) -} diff --git a/internal/configuration/sources/env/portforward.go b/internal/configuration/sources/env/portforward.go index 76ed0156..e2f6acb6 100644 --- a/internal/configuration/sources/env/portforward.go +++ b/internal/configuration/sources/env/portforward.go @@ -7,21 +7,21 @@ import ( func (s *Source) readPortForward() ( portForwarding settings.PortForwarding, err error) { - key, _ := s.getEnvWithRetro("VPN_PORT_FORWARDING", - []string{ - "PRIVATE_INTERNET_ACCESS_VPN_PORT_FORWARDING", + portForwarding.Enabled, err = s.env.BoolPtr("VPN_PORT_FORWARDING", + env.RetroKeys( "PORT_FORWARDING", - }) - portForwarding.Enabled, err = s.env.BoolPtr(key) + "PRIVATE_INTERNET_ACCESS_VPN_PORT_FORWARDING", + )) if err != nil { return portForwarding, err } - _, portForwarding.Filepath = s.getEnvWithRetro("VPN_PORT_FORWARDING_STATUS_FILE", - []string{ - "PRIVATE_INTERNET_ACCESS_VPN_PORT_FORWARDING_STATUS_FILE", + portForwarding.Filepath = s.env.Get("VPN_PORT_FORWARDING_STATUS_FILE", + env.ForceLowercase(false), + env.RetroKeys( "PORT_FORWARDING_STATUS_FILE", - }, env.ForceLowercase(false)) + "PRIVATE_INTERNET_ACCESS_VPN_PORT_FORWARDING_STATUS_FILE", + )) return portForwarding, nil } diff --git a/internal/configuration/sources/env/provider.go b/internal/configuration/sources/env/provider.go index e234c189..b14d029b 100644 --- a/internal/configuration/sources/env/provider.go +++ b/internal/configuration/sources/env/provider.go @@ -7,6 +7,7 @@ import ( "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/vpn" + "github.com/qdm12/gosettings/sources/env" ) func (s *Source) readProvider(vpnType string) (provider settings.Provider, err error) { @@ -30,7 +31,7 @@ func (s *Source) readProvider(vpnType string) (provider settings.Provider, err e } func (s *Source) readVPNServiceProvider(vpnType string) (vpnProviderPtr *string) { - _, valuePtr := s.getEnvWithRetro("VPN_SERVICE_PROVIDER", []string{"VPNSP"}) + valuePtr := s.env.Get("VPN_SERVICE_PROVIDER", env.RetroKeys("VPNSP")) if valuePtr == nil { if vpnType != vpn.Wireguard && s.env.Get("OPENVPN_CUSTOM_CONFIG") != nil { // retro compatibility diff --git a/internal/configuration/sources/env/publicip.go b/internal/configuration/sources/env/publicip.go index cb90254f..145e8141 100644 --- a/internal/configuration/sources/env/publicip.go +++ b/internal/configuration/sources/env/publicip.go @@ -11,8 +11,8 @@ func (s *Source) readPublicIP() (publicIP settings.PublicIP, err error) { return publicIP, err } - _, publicIP.IPFilepath = s.getEnvWithRetro("PUBLICIP_FILE", - []string{"IP_STATUS_FILE"}, env.ForceLowercase(false)) + publicIP.IPFilepath = s.env.Get("PUBLICIP_FILE", + env.ForceLowercase(false), env.RetroKeys("IP_STATUS_FILE")) return publicIP, nil } diff --git a/internal/configuration/sources/env/reader.go b/internal/configuration/sources/env/reader.go index 059adf89..a49439f8 100644 --- a/internal/configuration/sources/env/reader.go +++ b/internal/configuration/sources/env/reader.go @@ -8,8 +8,9 @@ import ( ) type Source struct { - env env.Env - warner Warner + env env.Env + warner Warner + handleDeprecatedKey func(deprecatedKey, newKey string) } type Warner interface { @@ -17,9 +18,16 @@ type Warner interface { } func New(warner Warner) *Source { + handleDeprecatedKey := func(deprecatedKey, newKey string) { + warner.Warn( + "You are using the old environment variable " + deprecatedKey + + ", please consider changing it to " + newKey) + } + return &Source{ - env: *env.New(os.Environ()), - warner: warner, + env: *env.New(os.Environ(), handleDeprecatedKey), + warner: warner, + handleDeprecatedKey: handleDeprecatedKey, } } @@ -93,30 +101,3 @@ func (s *Source) Read() (settings settings.Settings, err error) { return settings, nil } - -func (s *Source) onRetroActive(oldKey, newKey string) { - s.warner.Warn( - "You are using the old environment variable " + oldKey + - ", please consider changing it to " + newKey) -} - -// getEnvWithRetro returns the first set environment variable -// key and corresponding value from the environment -// variable keys given. It first goes through the retroKeys -// and end on returning the value corresponding to the currentKey. -// Note retroKeys should be in order from oldest to most -// recent retro-compatibility key. -func (s *Source) getEnvWithRetro(currentKey string, - retroKeys []string, options ...env.Option) (key string, value *string) { - // We check retro-compatibility keys first since - // the current key might be set in the Dockerfile. - for _, key = range retroKeys { - value = s.env.Get(key, options...) - if value != nil { - s.onRetroActive(key, currentKey) - return key, value - } - } - - return currentKey, s.env.Get(currentKey, options...) -} diff --git a/internal/configuration/sources/env/server.go b/internal/configuration/sources/env/server.go index a8a3eea2..31a9cf9e 100644 --- a/internal/configuration/sources/env/server.go +++ b/internal/configuration/sources/env/server.go @@ -16,17 +16,16 @@ func (s *Source) readControlServer() (controlServer settings.ControlServer, err } func (s *Source) readControlServerAddress() (address *string) { - key, value := s.getEnvWithRetro("HTTP_CONTROL_SERVER_ADDRESS", - []string{"CONTROL_SERVER_ADDRESS"}) + const currentKey = "HTTP_CONTROL_SERVER_ADDRESS" + key := firstKeySet(s.env, "CONTROL_SERVER_ADDRESS", currentKey) + if key == currentKey { + return s.env.Get(key) + } + + s.handleDeprecatedKey(key, currentKey) + value := s.env.Get("CONTROL_SERVER_ADDRESS") if value == nil { return nil } - - if key == "HTTP_CONTROL_SERVER_ADDRESS" { - return value - } - - address = new(string) - *address = ":" + *value - return address + return ptrTo(":" + *value) } diff --git a/internal/configuration/sources/env/serverselection.go b/internal/configuration/sources/env/serverselection.go index 7de9bbc3..1f8001a7 100644 --- a/internal/configuration/sources/env/serverselection.go +++ b/internal/configuration/sources/env/serverselection.go @@ -2,72 +2,43 @@ package env import ( "errors" - "fmt" - "net/netip" - "strconv" - "strings" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants/providers" -) - -var ( - ErrServerNumberNotValid = errors.New("server number is not valid") + "github.com/qdm12/gosettings/sources/env" ) func (s *Source) readServerSelection(vpnProvider, vpnType string) ( ss settings.ServerSelection, err error) { ss.VPN = vpnType - ss.TargetIP, err = s.readOpenVPNTargetIP() + ss.TargetIP, err = s.env.NetipAddr("VPN_ENDPOINT_IP", + env.RetroKeys("OPENVPN_TARGET_IP")) if err != nil { return ss, err } - countriesKey, _ := s.getEnvWithRetro("SERVER_COUNTRIES", []string{"COUNTRY"}) - ss.Countries = s.env.CSV(countriesKey) + ss.Countries = s.env.CSV("SERVER_COUNTRIES", env.RetroKeys("COUNTRY")) if vpnProvider == providers.Cyberghost && len(ss.Countries) == 0 { // Retro-compatibility for Cyberghost using the REGION variable ss.Countries = s.env.CSV("REGION") if len(ss.Countries) > 0 { - s.onRetroActive("REGION", "SERVER_COUNTRIES") + s.handleDeprecatedKey("REGION", "SERVER_COUNTRIES") } } - regionsKey, _ := s.getEnvWithRetro("SERVER_REGIONS", []string{"REGION"}) - ss.Regions = s.env.CSV(regionsKey) - - citiesKey, _ := s.getEnvWithRetro("SERVER_CITIES", []string{"CITY"}) - ss.Cities = s.env.CSV(citiesKey) - + ss.Regions = s.env.CSV("SERVER_REGIONS", env.RetroKeys("REGION")) + ss.Cities = s.env.CSV("SERVER_CITIES", env.RetroKeys("CITY")) ss.ISPs = s.env.CSV("ISP") - - hostnamesKey, _ := s.getEnvWithRetro("SERVER_HOSTNAMES", []string{"SERVER_HOSTNAME"}) - ss.Hostnames = s.env.CSV(hostnamesKey) - - serverNamesKey, _ := s.getEnvWithRetro("SERVER_NAMES", []string{"SERVER_NAME"}) - ss.Names = s.env.CSV(serverNamesKey) - - if csv := s.env.Get("SERVER_NUMBER"); csv != nil { - numbersStrings := strings.Split(*csv, ",") - numbers := make([]uint16, len(numbersStrings)) - for i, numberString := range numbersStrings { - const base, bitSize = 10, 16 - number, err := strconv.ParseInt(numberString, base, bitSize) - if err != nil { - return ss, fmt.Errorf("%w: %s", - ErrServerNumberNotValid, numberString) - } else if number < 0 || number > 65535 { - return ss, fmt.Errorf("%w: %d must be between 0 and 65535", - ErrServerNumberNotValid, number) - } - numbers[i] = uint16(number) - } - ss.Numbers = numbers + ss.Hostnames = s.env.CSV("SERVER_HOSTNAMES", env.RetroKeys("SERVER_HOSTNAME")) + ss.Names = s.env.CSV("SERVER_NAMES", env.RetroKeys("SERVER_NAME")) + ss.Numbers, err = s.env.CSVUint16("SERVER_NUMBER") + if err != nil { + return ss, err } // Mullvad only - ss.OwnedOnly, err = s.readOwnedOnly() + ss.OwnedOnly, err = s.env.BoolPtr("OWNED_ONLY", env.RetroKeys("OWNED")) if err != nil { return ss, err } @@ -112,22 +83,3 @@ func (s *Source) readServerSelection(vpnProvider, vpnType string) ( var ( ErrInvalidIP = errors.New("invalid IP address") ) - -func (s *Source) readOpenVPNTargetIP() (ip netip.Addr, err error) { - envKey, value := s.getEnvWithRetro("VPN_ENDPOINT_IP", []string{"OPENVPN_TARGET_IP"}) - if value == nil { - return ip, nil - } - - ip, err = netip.ParseAddr(*value) - if err != nil { - return ip, fmt.Errorf("environment variable %s: %w", envKey, err) - } - - return ip, nil -} - -func (s *Source) readOwnedOnly() (ownedOnly *bool, err error) { - envKey, _ := s.getEnvWithRetro("OWNED_ONLY", []string{"OWNED"}) - return s.env.BoolPtr(envKey) -} diff --git a/internal/configuration/sources/env/shadowsocks.go b/internal/configuration/sources/env/shadowsocks.go index ad496620..8b612374 100644 --- a/internal/configuration/sources/env/shadowsocks.go +++ b/internal/configuration/sources/env/shadowsocks.go @@ -1,6 +1,8 @@ package env import ( + "fmt" + "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gosettings/sources/env" ) @@ -11,35 +13,30 @@ func (s *Source) readShadowsocks() (shadowsocks settings.Shadowsocks, err error) return shadowsocks, err } - shadowsocks.Address = s.readShadowsocksAddress() + shadowsocks.Address, err = s.readShadowsocksAddress() + if err != nil { + return shadowsocks, err + } shadowsocks.LogAddresses, err = s.env.BoolPtr("SHADOWSOCKS_LOG") if err != nil { return shadowsocks, err } - shadowsocks.CipherName = s.readShadowsocksCipher() + shadowsocks.CipherName = s.env.String("SHADOWSOCKS_CIPHER", + env.RetroKeys("SHADOWSOCKS_METHOD")) shadowsocks.Password = s.env.Get("SHADOWSOCKS_PASSWORD", env.ForceLowercase(false)) return shadowsocks, nil } -func (s *Source) readShadowsocksAddress() (address *string) { - key, value := s.getEnvWithRetro("SHADOWSOCKS_LISTENING_ADDRESS", - []string{"SHADOWSOCKS_PORT"}) - if value == nil { - return nil +func (s *Source) readShadowsocksAddress() (address *string, err error) { + const currentKey = "SHADOWSOCKS_LISTENING_ADDRESS" + port, err := s.env.Uint16Ptr("SHADOWSOCKS_PORT") // retro-compatibility + if err != nil { + return nil, err + } else if port != nil { + s.handleDeprecatedKey("SHADOWSOCKS_PORT", currentKey) + return ptrTo(fmt.Sprintf(":%d", *port)), nil } - if key == "SHADOWSOCKS_LISTENING_ADDRESS" { - return value - } - - // Retro-compatibility - *value = ":" + *value - return value -} - -func (s *Source) readShadowsocksCipher() (cipher string) { - envKey, _ := s.getEnvWithRetro("SHADOWSOCKS_CIPHER", - []string{"SHADOWSOCKS_METHOD"}) - return s.env.String(envKey) + return s.env.Get(currentKey), nil } diff --git a/internal/configuration/sources/env/system.go b/internal/configuration/sources/env/system.go index 0d678ccf..3c5d9490 100644 --- a/internal/configuration/sources/env/system.go +++ b/internal/configuration/sources/env/system.go @@ -1,26 +1,17 @@ package env import ( - "errors" - "fmt" - "strconv" - "github.com/qdm12/gluetun/internal/configuration/settings" -) - -var ( - ErrSystemPUIDNotValid = errors.New("PUID is not valid") - ErrSystemPGIDNotValid = errors.New("PGID is not valid") - ErrSystemTimezoneNotValid = errors.New("timezone is not valid") + "github.com/qdm12/gosettings/sources/env" ) func (s *Source) readSystem() (system settings.System, err error) { - system.PUID, err = s.readID("PUID", "UID") + system.PUID, err = s.env.Uint32Ptr("PUID", env.RetroKeys("UID")) if err != nil { return system, err } - system.PGID, err = s.readID("PGID", "GID") + system.PGID, err = s.env.Uint32Ptr("PGID", env.RetroKeys("GID")) if err != nil { return system, err } @@ -29,28 +20,3 @@ func (s *Source) readSystem() (system settings.System, err error) { return system, nil } - -var ErrSystemIDNotValid = errors.New("system ID is not valid") - -func (s *Source) readID(key, retroKey string) ( - id *uint32, err error) { - idEnvKey, idStringPtr := s.getEnvWithRetro(key, []string{retroKey}) - if idStringPtr == nil { - return nil, nil //nolint:nilnil - } - idString := *idStringPtr - - const base = 10 - const bitSize = 64 - const max = uint64(^uint32(0)) - idUint64, err := strconv.ParseUint(idString, base, bitSize) - if err != nil { - return nil, fmt.Errorf("environment variable %s: %w: %s", - idEnvKey, ErrSystemIDNotValid, err) - } else if idUint64 > max { - return nil, fmt.Errorf("environment variable %s: %w: %d: must be between 0 and %d", - idEnvKey, ErrSystemIDNotValid, idUint64, max) - } - - return ptrTo(uint32(idUint64)), nil -} diff --git a/internal/configuration/sources/env/system_test.go b/internal/configuration/sources/env/system_test.go deleted file mode 100644 index dcefffff..00000000 --- a/internal/configuration/sources/env/system_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package env - -import ( - "testing" - - "github.com/qdm12/gosettings/sources/env" - "github.com/stretchr/testify/assert" -) - -func Test_Reader_readID(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - source Source - key string - retroKey string - id *uint32 - errWrapped error - errMessage string - }{ - "empty string": { - source: Source{ - env: *env.New([]string{ - "ID=", - }), - }, - key: "ID", - retroKey: "RETRO_ID", - }, - "invalid string": { - source: Source{ - env: *env.New([]string{ - "ID=invalid", - }), - }, - key: "ID", - retroKey: "RETRO_ID", - errWrapped: ErrSystemIDNotValid, - errMessage: `environment variable ID: ` + - `system ID is not valid: ` + - `strconv.ParseUint: parsing "invalid": invalid syntax`, - }, - "negative number": { - source: Source{ - env: *env.New([]string{ - "ID=-1", - }), - }, - key: "ID", - retroKey: "RETRO_ID", - errWrapped: ErrSystemIDNotValid, - errMessage: `environment variable ID: ` + - `system ID is not valid: ` + - `strconv.ParseUint: parsing "-1": invalid syntax`, - }, - "id 1000": { - source: Source{ - env: *env.New([]string{ - "ID=1000", - }), - }, - key: "ID", - retroKey: "RETRO_ID", - id: ptrTo(uint32(1000)), - }, - "max id": { - source: Source{ - env: *env.New([]string{ - "ID=4294967295", - }), - }, - key: "ID", - retroKey: "RETRO_ID", - id: ptrTo(uint32(4294967295)), - }, - "above max id": { - source: Source{ - env: *env.New([]string{ - "ID=4294967296", - }), - }, - key: "ID", - retroKey: "RETRO_ID", - errWrapped: ErrSystemIDNotValid, - errMessage: `environment variable ID: ` + - `system ID is not valid: 4294967296: must be between 0 and 4294967295`, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - id, err := testCase.source.readID(testCase.key, testCase.retroKey) - - assert.ErrorIs(t, err, testCase.errWrapped) - if err != nil { - assert.EqualError(t, err, testCase.errMessage) - } - - assert.Equal(t, testCase.id, id) - }) - } -} diff --git a/internal/configuration/sources/env/wireguard.go b/internal/configuration/sources/env/wireguard.go index 9bc296a4..8a5b2edc 100644 --- a/internal/configuration/sources/env/wireguard.go +++ b/internal/configuration/sources/env/wireguard.go @@ -1,10 +1,6 @@ package env import ( - "fmt" - "net/netip" - "strings" - "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gosettings/sources/env" ) @@ -15,11 +11,11 @@ func (s *Source) readWireguard() (wireguard settings.Wireguard, err error) { }() wireguard.PrivateKey = s.env.Get("WIREGUARD_PRIVATE_KEY", env.ForceLowercase(false)) wireguard.PreSharedKey = s.env.Get("WIREGUARD_PRESHARED_KEY", env.ForceLowercase(false)) - envKey, _ := s.getEnvWithRetro("VPN_INTERFACE", - []string{"WIREGUARD_INTERFACE"}, env.ForceLowercase(false)) - wireguard.Interface = s.env.String(envKey) + wireguard.Interface = s.env.String("VPN_INTERFACE", + env.RetroKeys("WIREGUARD_INTERFACE"), env.ForceLowercase(false)) wireguard.Implementation = s.env.String("WIREGUARD_IMPLEMENTATION") - wireguard.Addresses, err = s.readWireguardAddresses() + wireguard.Addresses, err = s.env.CSVNetipPrefixes("WIREGUARD_ADDRESSES", + env.RetroKeys("WIREGUARD_ADDRESS")) if err != nil { return wireguard, err // already wrapped } @@ -31,23 +27,3 @@ func (s *Source) readWireguard() (wireguard settings.Wireguard, err error) { } return wireguard, nil } - -func (s *Source) readWireguardAddresses() (addresses []netip.Prefix, err error) { - key, value := s.getEnvWithRetro("WIREGUARD_ADDRESSES", - []string{"WIREGUARD_ADDRESS"}) - if value == nil { - return nil, nil - } - - addressStrings := strings.Split(*value, ",") - addresses = make([]netip.Prefix, len(addressStrings)) - for i, addressString := range addressStrings { - addressString = strings.TrimSpace(addressString) - addresses[i], err = netip.ParsePrefix(addressString) - if err != nil { - return nil, fmt.Errorf("environment variable %s: %w", key, err) - } - } - - return addresses, nil -} diff --git a/internal/configuration/sources/env/wireguardselection.go b/internal/configuration/sources/env/wireguardselection.go index 17f359c4..ebc3a9de 100644 --- a/internal/configuration/sources/env/wireguardselection.go +++ b/internal/configuration/sources/env/wireguardselection.go @@ -1,21 +1,18 @@ package env import ( - "fmt" - "net/netip" - "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gosettings/sources/env" ) func (s *Source) readWireguardSelection() ( selection settings.WireguardSelection, err error) { - selection.EndpointIP, err = s.readWireguardEndpointIP() + selection.EndpointIP, err = s.env.NetipAddr("VPN_ENDPOINT_IP", env.RetroKeys("WIREGUARD_ENDPOINT_IP")) if err != nil { return selection, err } - selection.EndpointPort, err = s.readWireguardCustomPort() + selection.EndpointPort, err = s.env.Uint16Ptr("VPN_ENDPOINT_PORT", env.RetroKeys("WIREGUARD_ENDPOINT_PORT")) if err != nil { return selection, err } @@ -24,22 +21,3 @@ func (s *Source) readWireguardSelection() ( return selection, nil } - -func (s *Source) readWireguardEndpointIP() (endpointIP netip.Addr, err error) { - key, value := s.getEnvWithRetro("VPN_ENDPOINT_IP", []string{"WIREGUARD_ENDPOINT_IP"}) - if value == nil { - return endpointIP, nil - } - - endpointIP, err = netip.ParseAddr(*value) - if err != nil { - return endpointIP, fmt.Errorf("environment variable %s: %w", key, err) - } - - return endpointIP, nil -} - -func (s *Source) readWireguardCustomPort() (customPort *uint16, err error) { - envKey, _ := s.getEnvWithRetro("VPN_ENDPOINT_PORT", []string{"WIREGUARD_ENDPOINT_PORT"}) - return s.env.Uint16Ptr(envKey) -} diff --git a/internal/configuration/sources/secrets/helpers_test.go b/internal/configuration/sources/secrets/helpers_test.go index 1249bf64..fe2b6a45 100644 --- a/internal/configuration/sources/secrets/helpers_test.go +++ b/internal/configuration/sources/secrets/helpers_test.go @@ -50,7 +50,7 @@ func Test_readSecretFileAsStringPtr(t *testing.T) { source: func(tempDir string) Source { secretFilepath := filepath.Join(tempDir, "secret_file") environ := []string{"SECRET_FILE=" + secretFilepath} - return Source{env: *env.New(environ)} + return Source{env: *env.New(environ, nil)} }, defaultSecretFileName: "default_secret_file", secretPathEnvKey: "SECRET_FILE", diff --git a/internal/configuration/sources/secrets/reader.go b/internal/configuration/sources/secrets/reader.go index 71443cd1..10825aff 100644 --- a/internal/configuration/sources/secrets/reader.go +++ b/internal/configuration/sources/secrets/reader.go @@ -12,8 +12,9 @@ type Source struct { } func New() *Source { + handleDeprecatedKey := (func(deprecatedKey, newKey string))(nil) return &Source{ - env: *env.New(os.Environ()), + env: *env.New(os.Environ(), handleDeprecatedKey), } }