Compare commits

...

1 Commits

Author SHA1 Message Date
Quentin McGaw
89bd10fc33 Fix DNS_KEEP_NAMESERVER behavior 2021-01-03 16:38:46 +00:00
2 changed files with 34 additions and 24 deletions

View File

@@ -35,23 +35,20 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
_ = file.Close() _ = file.Close()
return err return err
} }
s := strings.TrimSuffix(string(data), "\n") s := strings.TrimSuffix(string(data), "\n")
lines := strings.Split(s, "\n")
if len(lines) == 1 && lines[0] == "" { lines := []string{
lines = nil "nameserver " + ip.String(),
} }
found := false for _, line := range strings.Split(s, "\n") {
if !keepNameserver { // default if line == "" ||
for i := range lines { (!keepNameserver && strings.HasPrefix(line, "nameserver ")) {
if strings.HasPrefix(lines[i], "nameserver ") { continue
lines[i] = "nameserver " + ip.String()
found = true
}
} }
lines = append(lines, line)
} }
if !found {
lines = append(lines, "nameserver "+ip.String())
}
s = strings.Join(lines, "\n") + "\n" s = strings.Join(lines, "\n") + "\n"
_, err = file.WriteString(s) _, err = file.WriteString(s)
if err != nil { if err != nil {

View File

@@ -18,18 +18,22 @@ import (
func Test_UseDNSSystemWide(t *testing.T) { func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel() t.Parallel()
tests := map[string]struct { tests := map[string]struct {
data []byte ip net.IP
writtenData string keepNameserver bool
openErr error data []byte
readErr error writtenData string
writeErr error openErr error
closeErr error readErr error
err error writeErr error
closeErr error
err error
}{ }{
"no data": { "no data": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n", writtenData: "nameserver 127.0.0.1\n",
}, },
"open error": { "open error": {
ip: net.IP{127, 0, 0, 1},
openErr: fmt.Errorf("error"), openErr: fmt.Errorf("error"),
err: fmt.Errorf("error"), err: fmt.Errorf("error"),
}, },
@@ -38,17 +42,26 @@ func Test_UseDNSSystemWide(t *testing.T) {
err: fmt.Errorf("error"), err: fmt.Errorf("error"),
}, },
"write error": { "write error": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n", writtenData: "nameserver 127.0.0.1\n",
writeErr: fmt.Errorf("error"), writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"), err: fmt.Errorf("error"),
}, },
"lines without nameserver": { "lines without nameserver": {
ip: net.IP{127, 0, 0, 1},
data: []byte("abc\ndef\n"), data: []byte("abc\ndef\n"),
writtenData: "abc\ndef\nnameserver 127.0.0.1\n", writtenData: "nameserver 127.0.0.1\nabc\ndef\n",
}, },
"lines with nameserver": { "lines with nameserver": {
ip: net.IP{127, 0, 0, 1},
data: []byte("abc\nnameserver abc def\ndef\n"), data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: "abc\nnameserver 127.0.0.1\ndef\n", writtenData: "nameserver 127.0.0.1\nabc\ndef\n",
},
"keep nameserver": {
ip: net.IP{127, 0, 0, 1},
keepNameserver: true,
data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: "nameserver 127.0.0.1\nabc\nnameserver abc def\ndef\n",
}, },
} }
for name, tc := range tests { for name, tc := range tests {
@@ -89,12 +102,12 @@ func Test_UseDNSSystemWide(t *testing.T) {
} }
logger := mock_logging.NewMockLogger(mockCtrl) logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("using DNS address %s system wide", "127.0.0.1") logger.EXPECT().Info("using DNS address %s system wide", tc.ip.String())
c := &configurator{ c := &configurator{
openFile: openFile, openFile: openFile,
logger: logger, logger: logger,
} }
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false) err := c.UseDNSSystemWide(tc.ip, tc.keepNameserver)
if tc.err != nil { if tc.err != nil {
require.Error(t, err) require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error()) assert.Equal(t, tc.err.Error(), err.Error())