DNS_KEEP_NAMESERVER variable, refers to #188

This commit is contained in:
Quentin McGaw
2020-07-11 23:51:53 +00:00
parent 78b63174ce
commit 8b096af04e
9 changed files with 32 additions and 11 deletions

View File

@@ -17,7 +17,7 @@ type Configurator interface {
DownloadRootKey(uid, gid int) error
MakeUnboundConf(settings settings.DNS, uid, gid int) (err error)
UseDNSInternally(IP net.IP)
UseDNSSystemWide(IP net.IP) error
UseDNSSystemWide(ip net.IP, keepNameserver bool) error
Start(ctx context.Context, logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error)
WaitForUnbound() (err error)
Version(ctx context.Context) (version string, err error)

View File

@@ -104,8 +104,8 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
// Started successfully
go l.streamMerger.Merge(unboundCtx, stream,
command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, l.settings.KeepNameserver); err != nil { // use Unbound
l.logger.Error(err)
}
if err := l.conf.WaitForUnbound(); err != nil {
@@ -148,7 +148,7 @@ func (l *looper) fallbackToUnencryptedDNS() {
if targetIP != nil {
l.logger.Info("falling back on plaintext DNS at address %s", targetIP)
l.conf.UseDNSInternally(targetIP)
if err := l.conf.UseDNSSystemWide(targetIP); err != nil {
if err := l.conf.UseDNSSystemWide(targetIP, l.settings.KeepNameserver); err != nil {
l.logger.Error(err)
}
return
@@ -161,7 +161,7 @@ func (l *looper) fallbackToUnencryptedDNS() {
if targetIP.To4() != nil {
l.logger.Info("falling back on plaintext DNS at address %s", targetIP)
l.conf.UseDNSInternally(targetIP)
if err := l.conf.UseDNSSystemWide(targetIP); err != nil {
if err := l.conf.UseDNSSystemWide(targetIP, l.settings.KeepNameserver); err != nil {
l.logger.Error(err)
}
return

View File

@@ -21,7 +21,7 @@ func (c *configurator) UseDNSInternally(ip net.IP) {
}
// UseDNSSystemWide changes the nameserver to use for DNS system wide
func (c *configurator) UseDNSSystemWide(ip net.IP) error {
func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
c.logger.Info("using DNS address %s system wide", ip.String())
data, err := c.fileManager.ReadFile(string(constants.ResolvConf))
if err != nil {
@@ -33,10 +33,12 @@ func (c *configurator) UseDNSSystemWide(ip net.IP) error {
lines = nil
}
found := false
for i := range lines {
if strings.HasPrefix(lines[i], "nameserver ") {
lines[i] = "nameserver " + ip.String()
found = true
if !keepNameserver { // default
for i := range lines {
if strings.HasPrefix(lines[i], "nameserver ") {
lines[i] = "nameserver " + ip.String()
found = true
}
}
}
if !found {

View File

@@ -62,7 +62,7 @@ func Test_UseDNSSystemWide(t *testing.T) {
fileManager: fileManager,
logger: logger,
}
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1})
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())