DNS_KEEP_NAMESERVER variable, refers to #188

This commit is contained in:
Quentin McGaw
2020-07-11 23:51:53 +00:00
parent 78b63174ce
commit 8b096af04e
9 changed files with 32 additions and 11 deletions

View File

@@ -76,6 +76,7 @@ ENV VPNSP=pia \
UNBLOCK= \
DNS_UPDATE_PERIOD=24h \
DNS_PLAINTEXT_ADDRESS=1.1.1.1 \
DNS_KEEP_NAMESERVER=off \
# Firewall
FIREWALL=on \
EXTRA_SUBNETS= \

View File

@@ -221,6 +221,7 @@ None of the following values are required.
| `BLOCK_ADS` | `off` | `on`, `off` | Block ads hostnames and IPs with Unbound |
| `UNBLOCK` | |i.e. `domain1.com,x.domain2.co.uk` | Comma separated list of domain names to leave unblocked with Unbound |
| `DNS_PLAINTEXT_ADDRESS` | `1.1.1.1` | Any IP address | IP address to use as DNS resolver if `DOT` is `off` |
| `DNS_KEEP_NAMESERVER` | `off` | `on` or `off` | Keep the nameservers in /etc/resolv.conf untouched, but disabled DNS blocking features |
### Firewall

View File

@@ -17,7 +17,7 @@ type Configurator interface {
DownloadRootKey(uid, gid int) error
MakeUnboundConf(settings settings.DNS, uid, gid int) (err error)
UseDNSInternally(IP net.IP)
UseDNSSystemWide(IP net.IP) error
UseDNSSystemWide(ip net.IP, keepNameserver bool) error
Start(ctx context.Context, logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error)
WaitForUnbound() (err error)
Version(ctx context.Context) (version string, err error)

View File

@@ -104,8 +104,8 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
// Started successfully
go l.streamMerger.Merge(unboundCtx, stream,
command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, l.settings.KeepNameserver); err != nil { // use Unbound
l.logger.Error(err)
}
if err := l.conf.WaitForUnbound(); err != nil {
@@ -148,7 +148,7 @@ func (l *looper) fallbackToUnencryptedDNS() {
if targetIP != nil {
l.logger.Info("falling back on plaintext DNS at address %s", targetIP)
l.conf.UseDNSInternally(targetIP)
if err := l.conf.UseDNSSystemWide(targetIP); err != nil {
if err := l.conf.UseDNSSystemWide(targetIP, l.settings.KeepNameserver); err != nil {
l.logger.Error(err)
}
return
@@ -161,7 +161,7 @@ func (l *looper) fallbackToUnencryptedDNS() {
if targetIP.To4() != nil {
l.logger.Info("falling back on plaintext DNS at address %s", targetIP)
l.conf.UseDNSInternally(targetIP)
if err := l.conf.UseDNSSystemWide(targetIP); err != nil {
if err := l.conf.UseDNSSystemWide(targetIP, l.settings.KeepNameserver); err != nil {
l.logger.Error(err)
}
return

View File

@@ -21,7 +21,7 @@ func (c *configurator) UseDNSInternally(ip net.IP) {
}
// UseDNSSystemWide changes the nameserver to use for DNS system wide
func (c *configurator) UseDNSSystemWide(ip net.IP) error {
func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
c.logger.Info("using DNS address %s system wide", ip.String())
data, err := c.fileManager.ReadFile(string(constants.ResolvConf))
if err != nil {
@@ -33,10 +33,12 @@ func (c *configurator) UseDNSSystemWide(ip net.IP) error {
lines = nil
}
found := false
for i := range lines {
if strings.HasPrefix(lines[i], "nameserver ") {
lines[i] = "nameserver " + ip.String()
found = true
if !keepNameserver { // default
for i := range lines {
if strings.HasPrefix(lines[i], "nameserver ") {
lines[i] = "nameserver " + ip.String()
found = true
}
}
}
if !found {

View File

@@ -62,7 +62,7 @@ func Test_UseDNSSystemWide(t *testing.T) {
fileManager: fileManager,
logger: logger,
}
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1})
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())

View File

@@ -157,3 +157,9 @@ func (r *reader) GetDNSPlaintext() (ip net.IP, err error) {
}
return ip, nil
}
// GetDNSKeepNameserver obtains if the nameserver present in /etc/resolv.conf
// should be kept instead of overridden, from the environment variable DNS_KEEP_NAMESERVER
func (r *reader) GetDNSKeepNameserver() (on bool, err error) {
return r.envParams.GetOnOff("DNS_KEEP_NAMESERVER", libparams.Default("off"))
}

View File

@@ -30,6 +30,7 @@ type Reader interface {
GetDNSOverTLSIPv6() (ipv6 bool, err error)
GetDNSUpdatePeriod() (period time.Duration, err error)
GetDNSPlaintext() (ip net.IP, err error)
GetDNSKeepNameserver() (on bool, err error)
// System
GetUID() (uid int, err error)

View File

@@ -14,6 +14,7 @@ import (
// DNS contains settings to configure Unbound for DNS over TLS operation
type DNS struct {
Enabled bool
KeepNameserver bool
Providers []models.DNSProvider
PlaintextAddress net.IP
AllowedHostnames []string
@@ -61,6 +62,10 @@ func (d *DNS) String() string {
if d.UpdatePeriod > 0 {
update = fmt.Sprintf("every %s", d.UpdatePeriod)
}
keepNameserver := "no"
if d.KeepNameserver {
keepNameserver = "yes"
}
settingsList := []string{
"DNS over TLS settings:",
"DNS over TLS provider:\n |--" + strings.Join(providersStr, "\n |--"),
@@ -75,6 +80,7 @@ func (d *DNS) String() string {
"Validation log level: " + fmt.Sprintf("%d/2", d.ValidationLogLevel),
"IPv6 resolution: " + ipv6,
"Update: " + update,
"Keep nameserver (disabled blocking): " + keepNameserver,
}
return strings.Join(settingsList, "\n |--")
}
@@ -137,6 +143,10 @@ func GetDNSSettings(paramsReader params.Reader) (settings DNS, err error) {
if err != nil {
return settings, err
}
settings.KeepNameserver, err = paramsReader.GetDNSKeepNameserver()
if err != nil {
return settings, err
}
// Consistency check
IPv6Support := false