diff --git a/internal/dns/nameserver.go b/internal/dns/nameserver.go index 48c2e2a6..facc615a 100644 --- a/internal/dns/nameserver.go +++ b/internal/dns/nameserver.go @@ -35,23 +35,20 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error { _ = file.Close() return err } + s := strings.TrimSuffix(string(data), "\n") - lines := strings.Split(s, "\n") - if len(lines) == 1 && lines[0] == "" { - lines = nil + + lines := []string{ + "nameserver " + ip.String(), } - found := false - if !keepNameserver { // default - for i := range lines { - if strings.HasPrefix(lines[i], "nameserver ") { - lines[i] = "nameserver " + ip.String() - found = true - } + for _, line := range strings.Split(s, "\n") { + if line == "" || + (!keepNameserver && strings.HasPrefix(line, "nameserver ")) { + continue } + lines = append(lines, line) } - if !found { - lines = append(lines, "nameserver "+ip.String()) - } + s = strings.Join(lines, "\n") + "\n" _, err = file.WriteString(s) if err != nil { diff --git a/internal/dns/nameserver_test.go b/internal/dns/nameserver_test.go index 44c1123f..7cc6823f 100644 --- a/internal/dns/nameserver_test.go +++ b/internal/dns/nameserver_test.go @@ -18,18 +18,22 @@ import ( func Test_UseDNSSystemWide(t *testing.T) { t.Parallel() tests := map[string]struct { - data []byte - writtenData string - openErr error - readErr error - writeErr error - closeErr error - err error + ip net.IP + keepNameserver bool + data []byte + writtenData string + openErr error + readErr error + writeErr error + closeErr error + err error }{ "no data": { + ip: net.IP{127, 0, 0, 1}, writtenData: "nameserver 127.0.0.1\n", }, "open error": { + ip: net.IP{127, 0, 0, 1}, openErr: fmt.Errorf("error"), err: fmt.Errorf("error"), }, @@ -38,17 +42,26 @@ func Test_UseDNSSystemWide(t *testing.T) { err: fmt.Errorf("error"), }, "write error": { + ip: net.IP{127, 0, 0, 1}, writtenData: "nameserver 127.0.0.1\n", writeErr: fmt.Errorf("error"), err: fmt.Errorf("error"), }, "lines without nameserver": { + ip: net.IP{127, 0, 0, 1}, data: []byte("abc\ndef\n"), - writtenData: "abc\ndef\nnameserver 127.0.0.1\n", + writtenData: "nameserver 127.0.0.1\nabc\ndef\n", }, "lines with nameserver": { + ip: net.IP{127, 0, 0, 1}, data: []byte("abc\nnameserver abc def\ndef\n"), - writtenData: "abc\nnameserver 127.0.0.1\ndef\n", + writtenData: "nameserver 127.0.0.1\nabc\ndef\n", + }, + "keep nameserver": { + ip: net.IP{127, 0, 0, 1}, + keepNameserver: true, + data: []byte("abc\nnameserver abc def\ndef\n"), + writtenData: "nameserver 127.0.0.1\nabc\nnameserver abc def\ndef\n", }, } for name, tc := range tests { @@ -89,12 +102,12 @@ func Test_UseDNSSystemWide(t *testing.T) { } logger := mock_logging.NewMockLogger(mockCtrl) - logger.EXPECT().Info("using DNS address %s system wide", "127.0.0.1") + logger.EXPECT().Info("using DNS address %s system wide", tc.ip.String()) c := &configurator{ openFile: openFile, logger: logger, } - err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false) + err := c.UseDNSSystemWide(tc.ip, tc.keepNameserver) if tc.err != nil { require.Error(t, err) assert.Equal(t, tc.err.Error(), err.Error())