chore(settings): use gosettings/sources/env functions

This commit is contained in:
Quentin McGaw
2023-05-30 13:02:10 +00:00
parent 2c30984a10
commit b87b2109b1
19 changed files with 67 additions and 169 deletions

View File

@@ -5,6 +5,7 @@ import (
"net/netip"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gosettings/sources/env"
)
func (s *Source) readDNS() (dns settings.DNS, err error) {
@@ -13,7 +14,7 @@ func (s *Source) readDNS() (dns settings.DNS, err error) {
return dns, err
}
dns.KeepNameserver, err = envToBoolPtr("DNS_KEEP_NAMESERVER")
dns.KeepNameserver, err = env.BoolPtr("DNS_KEEP_NAMESERVER")
if err != nil {
return dns, fmt.Errorf("environment variable DNS_KEEP_NAMESERVER: %w", err)
}

View File

@@ -6,11 +6,12 @@ import (
"net/netip"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gosettings/sources/env"
"github.com/qdm12/govalid/binary"
)
func (s *Source) readDNSBlacklist() (blacklist settings.DNSBlacklist, err error) {
blacklist.BlockMalicious, err = envToBoolPtr("BLOCK_MALICIOUS")
blacklist.BlockMalicious, err = env.BoolPtr("BLOCK_MALICIOUS")
if err != nil {
return blacklist, fmt.Errorf("environment variable BLOCK_MALICIOUS: %w", err)
}
@@ -20,7 +21,7 @@ func (s *Source) readDNSBlacklist() (blacklist settings.DNSBlacklist, err error)
return blacklist, err
}
blacklist.BlockAds, err = envToBoolPtr("BLOCK_ADS")
blacklist.BlockAds, err = env.BoolPtr("BLOCK_ADS")
if err != nil {
return blacklist, fmt.Errorf("environment variable BLOCK_ADS: %w", err)
}
@@ -31,7 +32,7 @@ func (s *Source) readDNSBlacklist() (blacklist settings.DNSBlacklist, err error)
return blacklist, err
}
blacklist.AllowedHosts = envToCSV("UNBLOCK") // TODO v4 change name
blacklist.AllowedHosts = env.CSV("UNBLOCK") // TODO v4 change name
return blacklist, nil
}
@@ -52,7 +53,7 @@ var (
func readDoTPrivateAddresses() (ips []netip.Addr,
ipPrefixes []netip.Prefix, err error) {
privateAddresses := envToCSV("DOT_PRIVATE_ADDRESS")
privateAddresses := env.CSV("DOT_PRIVATE_ADDRESS")
if len(privateAddresses) == 0 {
return nil, nil, nil
}

View File

@@ -4,15 +4,16 @@ import (
"fmt"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gosettings/sources/env"
)
func (s *Source) readDoT() (dot settings.DoT, err error) {
dot.Enabled, err = envToBoolPtr("DOT")
dot.Enabled, err = env.BoolPtr("DOT")
if err != nil {
return dot, fmt.Errorf("environment variable DOT: %w", err)
}
dot.UpdatePeriod, err = envToDurationPtr("DNS_UPDATE_PERIOD")
dot.UpdatePeriod, err = env.DurationPtr("DNS_UPDATE_PERIOD")
if err != nil {
return dot, fmt.Errorf("environment variable DNS_UPDATE_PERIOD: %w", err)
}

View File

@@ -7,34 +7,35 @@ import (
"strconv"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gosettings/sources/env"
)
func (s *Source) readFirewall() (firewall settings.Firewall, err error) {
vpnInputPortStrings := envToCSV("FIREWALL_VPN_INPUT_PORTS")
vpnInputPortStrings := env.CSV("FIREWALL_VPN_INPUT_PORTS")
firewall.VPNInputPorts, err = stringsToPorts(vpnInputPortStrings)
if err != nil {
return firewall, fmt.Errorf("environment variable FIREWALL_VPN_INPUT_PORTS: %w", err)
}
inputPortStrings := envToCSV("FIREWALL_INPUT_PORTS")
inputPortStrings := env.CSV("FIREWALL_INPUT_PORTS")
firewall.InputPorts, err = stringsToPorts(inputPortStrings)
if err != nil {
return firewall, fmt.Errorf("environment variable FIREWALL_INPUT_PORTS: %w", err)
}
outboundSubnetsKey, _ := s.getEnvWithRetro("FIREWALL_OUTBOUND_SUBNETS", []string{"EXTRA_SUBNETS"})
outboundSubnetStrings := envToCSV(outboundSubnetsKey)
outboundSubnetStrings := env.CSV(outboundSubnetsKey)
firewall.OutboundSubnets, err = stringsToNetipPrefixes(outboundSubnetStrings)
if err != nil {
return firewall, fmt.Errorf("environment variable %s: %w", outboundSubnetsKey, err)
}
firewall.Enabled, err = envToBoolPtr("FIREWALL")
firewall.Enabled, err = env.BoolPtr("FIREWALL")
if err != nil {
return firewall, fmt.Errorf("environment variable FIREWALL: %w", err)
}
firewall.Debug, err = envToBoolPtr("FIREWALL_DEBUG")
firewall.Debug, err = env.BoolPtr("FIREWALL_DEBUG")
if err != nil {
return firewall, fmt.Errorf("environment variable FIREWALL_DEBUG: %w", err)
}

View File

@@ -12,7 +12,7 @@ func (s *Source) ReadHealth() (health settings.Health, err error) {
health.ServerAddress = env.Get("HEALTH_SERVER_ADDRESS")
_, health.TargetAddress = s.getEnvWithRetro("HEALTH_TARGET_ADDRESS", []string{"HEALTH_ADDRESS_TO_PING"})
successWaitPtr, err := envToDurationPtr("HEALTH_SUCCESS_WAIT_DURATION")
successWaitPtr, err := env.DurationPtr("HEALTH_SUCCESS_WAIT_DURATION")
if err != nil {
return health, fmt.Errorf("environment variable HEALTH_SUCCESS_WAIT_DURATION: %w", err)
} else if successWaitPtr != nil {

View File

@@ -3,115 +3,8 @@ package env
import (
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/qdm12/gosettings/sources/env"
"github.com/qdm12/govalid/binary"
"github.com/qdm12/govalid/integer"
)
func envToCSV(envKey string) (values []string) {
csv := env.Get(envKey)
if csv == "" {
return nil
}
return lowerAndSplit(csv)
}
func envToFloat64(envKey string) (f float64, err error) {
s := env.Get(envKey)
if s == "" {
return 0, nil
}
const bits = 64
return strconv.ParseFloat(s, bits)
}
func envToStringPtr(envKey string, options ...env.Option) (stringPtr *string) {
s := env.Get(envKey, options...)
if s == "" {
return nil
}
return &s
}
func envToBoolPtr(envKey string) (boolPtr *bool, err error) {
s := env.Get(envKey)
value, err := binary.Validate(s)
if err != nil {
return nil, err
}
return value, nil
}
func envToIntPtr(envKey string) (intPtr *int, err error) {
s := env.Get(envKey)
if s == "" {
return nil, nil //nolint:nilnil
}
value, err := strconv.Atoi(s)
if err != nil {
return nil, err
}
return &value, nil
}
func envToUint8Ptr(envKey string) (uint8Ptr *uint8, err error) {
s := env.Get(envKey)
if s == "" {
return nil, nil //nolint:nilnil
}
const min, max = 0, 255
value, err := integer.Validate(s, integer.OptionRange(min, max))
if err != nil {
return nil, err
}
uint8Ptr = new(uint8)
*uint8Ptr = uint8(value)
return uint8Ptr, nil
}
func envToUint16Ptr(envKey string) (uint16Ptr *uint16, err error) {
s := env.Get(envKey)
if s == "" {
return nil, nil //nolint:nilnil
}
const min, max = 0, 65535
value, err := integer.Validate(s, integer.OptionRange(min, max))
if err != nil {
return nil, err
}
uint16Ptr = new(uint16)
*uint16Ptr = uint16(value)
return uint16Ptr, nil
}
func envToDurationPtr(envKey string) (durationPtr *time.Duration, err error) {
s := env.Get(envKey)
if s == "" {
return nil, nil //nolint:nilnil
}
durationPtr = new(time.Duration)
*durationPtr, err = time.ParseDuration(s)
if err != nil {
return nil, err
}
return durationPtr, nil
}
func lowerAndSplit(csv string) (values []string) {
csv = strings.ToLower(csv)
return strings.Split(csv, ",")
}
func unsetEnvKeys(envKeys []string, err error) (newErr error) {
newErr = err
for _, envKey := range envKeys {
@@ -123,6 +16,6 @@ func unsetEnvKeys(envKeys []string, err error) (newErr error) {
return newErr
}
func stringPtr(s string) *string { return &s }
func uint32Ptr(n uint32) *uint32 { return &n }
func boolPtr(b bool) *bool { return &b }
func ptrTo[T any](value T) *T {
return &value
}

View File

@@ -18,7 +18,7 @@ func (s *Source) readHTTPProxy() (httpProxy settings.HTTPProxy, err error) {
return httpProxy, err
}
httpProxy.Stealth, err = envToBoolPtr("HTTPPROXY_STEALTH")
httpProxy.Stealth, err = env.BoolPtr("HTTPPROXY_STEALTH")
if err != nil {
return httpProxy, fmt.Errorf("environment variable HTTPPROXY_STEALTH: %w", err)
}

View File

@@ -25,22 +25,22 @@ func (s *Source) readOpenVPN() (
}
ciphersKey, _ := s.getEnvWithRetro("OPENVPN_CIPHERS", []string{"OPENVPN_CIPHER"})
openVPN.Ciphers = envToCSV(ciphersKey)
openVPN.Ciphers = env.CSV(ciphersKey)
auth := env.Get("OPENVPN_AUTH")
if auth != "" {
openVPN.Auth = &auth
}
openVPN.Cert = envToStringPtr("OPENVPN_CERT", env.ForceLowercase(false))
openVPN.Key = envToStringPtr("OPENVPN_KEY", env.ForceLowercase(false))
openVPN.EncryptedKey = envToStringPtr("OPENVPN_ENCRYPTED_KEY", env.ForceLowercase(false))
openVPN.Cert = env.StringPtr("OPENVPN_CERT", env.ForceLowercase(false))
openVPN.Key = env.StringPtr("OPENVPN_KEY", env.ForceLowercase(false))
openVPN.EncryptedKey = env.StringPtr("OPENVPN_ENCRYPTED_KEY", env.ForceLowercase(false))
openVPN.KeyPassphrase = s.readOpenVPNKeyPassphrase()
openVPN.PIAEncPreset = s.readPIAEncryptionPreset()
openVPN.MSSFix, err = envToUint16Ptr("OPENVPN_MSSFIX")
openVPN.MSSFix, err = env.Uint16Ptr("OPENVPN_MSSFIX")
if err != nil {
return openVPN, fmt.Errorf("environment variable OPENVPN_MSSFIX: %w", err)
}
@@ -53,7 +53,7 @@ func (s *Source) readOpenVPN() (
return openVPN, err
}
openVPN.Verbosity, err = envToIntPtr("OPENVPN_VERBOSITY")
openVPN.Verbosity, err = env.IntPtr("OPENVPN_VERBOSITY")
if err != nil {
return openVPN, fmt.Errorf("environment variable OPENVPN_VERBOSITY: %w", err)
}

View File

@@ -42,9 +42,9 @@ func (s *Source) readOpenVPNProtocol() (tcp *bool, err error) {
case "":
return nil, nil //nolint:nilnil
case constants.UDP:
return boolPtr(false), nil
return ptrTo(false), nil
case constants.TCP:
return boolPtr(true), nil
return ptrTo(true), nil
default:
return nil, fmt.Errorf("environment variable %s: %w: %s",
envKey, ErrOpenVPNProtocolNotValid, protocol)

View File

@@ -14,7 +14,7 @@ func (s *Source) readPortForward() (
"PRIVATE_INTERNET_ACCESS_VPN_PORT_FORWARDING",
"PORT_FORWARDING",
})
portForwarding.Enabled, err = envToBoolPtr(key)
portForwarding.Enabled, err = env.BoolPtr(key)
if err != nil {
return portForwarding, fmt.Errorf("environment variable %s: %w", key, err)
}
@@ -25,7 +25,7 @@ func (s *Source) readPortForward() (
"PORT_FORWARDING_STATUS_FILE",
}, env.ForceLowercase(false))
if value != "" {
portForwarding.Filepath = stringPtr(value)
portForwarding.Filepath = ptrTo(value)
}
return portForwarding, nil

View File

@@ -8,17 +8,17 @@ import (
)
func readPprof() (settings pprof.Settings, err error) {
settings.Enabled, err = envToBoolPtr("PPROF_ENABLED")
settings.Enabled, err = env.BoolPtr("PPROF_ENABLED")
if err != nil {
return settings, fmt.Errorf("environment variable PPROF_ENABLED: %w", err)
}
settings.BlockProfileRate, err = envToIntPtr("PPROF_BLOCK_PROFILE_RATE")
settings.BlockProfileRate, err = env.IntPtr("PPROF_BLOCK_PROFILE_RATE")
if err != nil {
return settings, fmt.Errorf("environment variable PPROF_BLOCK_PROFILE_RATE: %w", err)
}
settings.MutexProfileRate, err = envToIntPtr("PPROF_MUTEX_PROFILE_RATE")
settings.MutexProfileRate, err = env.IntPtr("PPROF_MUTEX_PROFILE_RATE")
if err != nil {
return settings, fmt.Errorf("environment variable PPROF_MUTEX_PROFILE_RATE: %w", err)
}

View File

@@ -35,15 +35,15 @@ func (s *Source) readVPNServiceProvider(vpnType string) (vpnProviderPtr *string)
if value == "" {
if vpnType != vpn.Wireguard && env.Get("OPENVPN_CUSTOM_CONFIG") != "" {
// retro compatibility
return stringPtr(providers.Custom)
return ptrTo(providers.Custom)
}
return nil
}
value = strings.ToLower(value)
if value == "pia" { // retro compatibility
return stringPtr(providers.PrivateInternetAccess)
return ptrTo(providers.PrivateInternetAccess)
}
return stringPtr(value)
return ptrTo(value)
}

View File

@@ -26,28 +26,28 @@ func (s *Source) readServerSelection(vpnProvider, vpnType string) (
}
countriesKey, _ := s.getEnvWithRetro("SERVER_COUNTRIES", []string{"COUNTRY"})
ss.Countries = envToCSV(countriesKey)
ss.Countries = env.CSV(countriesKey)
if vpnProvider == providers.Cyberghost && len(ss.Countries) == 0 {
// Retro-compatibility for Cyberghost using the REGION variable
ss.Countries = envToCSV("REGION")
ss.Countries = env.CSV("REGION")
if len(ss.Countries) > 0 {
s.onRetroActive("REGION", "SERVER_COUNTRIES")
}
}
regionsKey, _ := s.getEnvWithRetro("SERVER_REGIONS", []string{"REGION"})
ss.Regions = envToCSV(regionsKey)
ss.Regions = env.CSV(regionsKey)
citiesKey, _ := s.getEnvWithRetro("SERVER_CITIES", []string{"CITY"})
ss.Cities = envToCSV(citiesKey)
ss.Cities = env.CSV(citiesKey)
ss.ISPs = envToCSV("ISP")
ss.ISPs = env.CSV("ISP")
hostnamesKey, _ := s.getEnvWithRetro("SERVER_HOSTNAMES", []string{"SERVER_HOSTNAME"})
ss.Hostnames = envToCSV(hostnamesKey)
ss.Hostnames = env.CSV(hostnamesKey)
serverNamesKey, _ := s.getEnvWithRetro("SERVER_NAMES", []string{"SERVER_NAME"})
ss.Names = envToCSV(serverNamesKey)
ss.Names = env.CSV(serverNamesKey)
if csv := env.Get("SERVER_NUMBER"); csv != "" {
numbersStrings := strings.Split(csv, ",")
@@ -74,25 +74,25 @@ func (s *Source) readServerSelection(vpnProvider, vpnType string) (
}
// VPNUnlimited and ProtonVPN only
ss.FreeOnly, err = envToBoolPtr("FREE_ONLY")
ss.FreeOnly, err = env.BoolPtr("FREE_ONLY")
if err != nil {
return ss, fmt.Errorf("environment variable FREE_ONLY: %w", err)
}
// VPNSecure only
ss.PremiumOnly, err = envToBoolPtr("PREMIUM_ONLY")
ss.PremiumOnly, err = env.BoolPtr("PREMIUM_ONLY")
if err != nil {
return ss, fmt.Errorf("environment variable PREMIUM_ONLY: %w", err)
}
// VPNUnlimited only
ss.MultiHopOnly, err = envToBoolPtr("MULTIHOP_ONLY")
ss.MultiHopOnly, err = env.BoolPtr("MULTIHOP_ONLY")
if err != nil {
return ss, fmt.Errorf("environment variable MULTIHOP_ONLY: %w", err)
}
// VPNUnlimited only
ss.MultiHopOnly, err = envToBoolPtr("STREAM_ONLY")
ss.MultiHopOnly, err = env.BoolPtr("STREAM_ONLY")
if err != nil {
return ss, fmt.Errorf("environment variable STREAM_ONLY: %w", err)
}
@@ -130,7 +130,7 @@ func (s *Source) readOpenVPNTargetIP() (ip netip.Addr, err error) {
func (s *Source) readOwnedOnly() (ownedOnly *bool, err error) {
envKey, _ := s.getEnvWithRetro("OWNED_ONLY", []string{"OWNED"})
ownedOnly, err = envToBoolPtr(envKey)
ownedOnly, err = env.BoolPtr(envKey)
if err != nil {
return nil, fmt.Errorf("environment variable %s: %w", envKey, err)
}

View File

@@ -9,18 +9,18 @@ import (
)
func (s *Source) readShadowsocks() (shadowsocks settings.Shadowsocks, err error) {
shadowsocks.Enabled, err = envToBoolPtr("SHADOWSOCKS")
shadowsocks.Enabled, err = env.BoolPtr("SHADOWSOCKS")
if err != nil {
return shadowsocks, fmt.Errorf("environment variable SHADOWSOCKS: %w", err)
}
shadowsocks.Address = s.readShadowsocksAddress()
shadowsocks.LogAddresses, err = envToBoolPtr("SHADOWSOCKS_LOG")
shadowsocks.LogAddresses, err = env.BoolPtr("SHADOWSOCKS_LOG")
if err != nil {
return shadowsocks, fmt.Errorf("environment variable SHADOWSOCKS_LOG: %w", err)
}
shadowsocks.CipherName = s.readShadowsocksCipher()
shadowsocks.Password = envToStringPtr("SHADOWSOCKS_PASSWORD", env.ForceLowercase(false))
shadowsocks.Password = env.StringPtr("SHADOWSOCKS_PASSWORD", env.ForceLowercase(false))
return shadowsocks, nil
}

View File

@@ -52,5 +52,5 @@ func (s *Source) readID(key, retroKey string) (
idEnvKey, ErrSystemIDNotValid, idUint64, max)
}
return uint32Ptr(uint32(idUint64)), nil
return ptrTo(uint32(idUint64)), nil
}

View File

@@ -44,13 +44,13 @@ func Test_Reader_readID(t *testing.T) {
keyPrefix: "ID",
keyValue: "1000",
retroKeyPrefix: "RETRO_ID",
id: uint32Ptr(1000),
id: ptrTo(uint32(1000)),
},
"max id": {
keyPrefix: "ID",
keyValue: "4294967295",
retroKeyPrefix: "RETRO_ID",
id: uint32Ptr(4294967295),
id: ptrTo(uint32(4294967295)),
},
"above max id": {
keyPrefix: "ID",

View File

@@ -4,32 +4,33 @@ import (
"fmt"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gosettings/sources/env"
)
func readUnbound() (unbound settings.Unbound, err error) {
unbound.Providers = envToCSV("DOT_PROVIDERS")
unbound.Providers = env.CSV("DOT_PROVIDERS")
unbound.Caching, err = envToBoolPtr("DOT_CACHING")
unbound.Caching, err = env.BoolPtr("DOT_CACHING")
if err != nil {
return unbound, fmt.Errorf("environment variable DOT_CACHING: %w", err)
}
unbound.IPv6, err = envToBoolPtr("DOT_IPV6")
unbound.IPv6, err = env.BoolPtr("DOT_IPV6")
if err != nil {
return unbound, fmt.Errorf("environment variable DOT_IPV6: %w", err)
}
unbound.VerbosityLevel, err = envToUint8Ptr("DOT_VERBOSITY")
unbound.VerbosityLevel, err = env.Uint8Ptr("DOT_VERBOSITY")
if err != nil {
return unbound, fmt.Errorf("environment variable DOT_VERBOSITY: %w", err)
}
unbound.VerbosityDetailsLevel, err = envToUint8Ptr("DOT_VERBOSITY_DETAILS")
unbound.VerbosityDetailsLevel, err = env.Uint8Ptr("DOT_VERBOSITY_DETAILS")
if err != nil {
return unbound, fmt.Errorf("environment variable DOT_VERBOSITY_DETAILS: %w", err)
}
unbound.ValidationLogLevel, err = envToUint8Ptr("DOT_VALIDATION_LOGLEVEL")
unbound.ValidationLogLevel, err = env.Uint8Ptr("DOT_VALIDATION_LOGLEVEL")
if err != nil {
return unbound, fmt.Errorf("environment variable DOT_VALIDATION_LOGLEVEL: %w", err)
}

View File

@@ -19,12 +19,12 @@ func readUpdater() (updater settings.Updater, err error) {
return updater, err
}
updater.MinRatio, err = envToFloat64("UPDATER_MIN_RATIO")
updater.MinRatio, err = env.Float64("UPDATER_MIN_RATIO")
if err != nil {
return updater, fmt.Errorf("environment variable UPDATER_MIN_RATIO: %w", err)
}
updater.Providers = envToCSV("UPDATER_VPN_SERVICE_PROVIDERS")
updater.Providers = env.CSV("UPDATER_VPN_SERVICE_PROVIDERS")
return updater, nil
}

View File

@@ -13,8 +13,8 @@ func (s *Source) readWireguard() (wireguard settings.Wireguard, err error) {
defer func() {
err = unsetEnvKeys([]string{"WIREGUARD_PRIVATE_KEY", "WIREGUARD_PRESHARED_KEY"}, err)
}()
wireguard.PrivateKey = envToStringPtr("WIREGUARD_PRIVATE_KEY", env.ForceLowercase(false))
wireguard.PreSharedKey = envToStringPtr("WIREGUARD_PRESHARED_KEY", env.ForceLowercase(false))
wireguard.PrivateKey = env.StringPtr("WIREGUARD_PRIVATE_KEY", env.ForceLowercase(false))
wireguard.PreSharedKey = env.StringPtr("WIREGUARD_PRESHARED_KEY", env.ForceLowercase(false))
_, wireguard.Interface = s.getEnvWithRetro("VPN_INTERFACE",
[]string{"WIREGUARD_INTERFACE"}, env.ForceLowercase(false))
wireguard.Implementation = env.Get("WIREGUARD_IMPLEMENTATION")
@@ -22,7 +22,7 @@ func (s *Source) readWireguard() (wireguard settings.Wireguard, err error) {
if err != nil {
return wireguard, err // already wrapped
}
mtuPtr, err := envToUint16Ptr("WIREGUARD_MTU")
mtuPtr, err := env.Uint16Ptr("WIREGUARD_MTU")
if err != nil {
return wireguard, fmt.Errorf("environment variable WIREGUARD_MTU: %w", err)
} else if mtuPtr != nil {