diff --git a/internal/dns/nameserver.go b/internal/dns/nameserver.go index facc615a..8ab3c920 100644 --- a/internal/dns/nameserver.go +++ b/internal/dns/nameserver.go @@ -26,7 +26,7 @@ func (c *configurator) UseDNSInternally(ip net.IP) { func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error { c.logger.Info("using DNS address %s system wide", ip.String()) const filepath = string(constants.ResolvConf) - file, err := c.openFile(filepath, os.O_RDWR|os.O_TRUNC, 0644) + file, err := c.openFile(filepath, os.O_RDONLY, 0) if err != nil { return err } @@ -35,6 +35,9 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error { _ = file.Close() return err } + if err := file.Close(); err != nil { + return err + } s := strings.TrimSuffix(string(data), "\n") @@ -50,6 +53,11 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error { } s = strings.Join(lines, "\n") + "\n" + + file, err = c.openFile(filepath, os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return err + } _, err = file.WriteString(s) if err != nil { _ = file.Close() diff --git a/internal/dns/nameserver_test.go b/internal/dns/nameserver_test.go index 7cc6823f..8af3867e 100644 --- a/internal/dns/nameserver_test.go +++ b/internal/dns/nameserver_test.go @@ -17,36 +17,54 @@ import ( func Test_UseDNSSystemWide(t *testing.T) { t.Parallel() + tests := map[string]struct { ip net.IP keepNameserver bool data []byte - writtenData string - openErr error + firstOpenErr error readErr error + firstCloseErr error + secondOpenErr error + writtenData string writeErr error - closeErr error + secondCloseErr error err error }{ "no data": { ip: net.IP{127, 0, 0, 1}, writtenData: "nameserver 127.0.0.1\n", }, - "open error": { - ip: net.IP{127, 0, 0, 1}, - openErr: fmt.Errorf("error"), - err: fmt.Errorf("error"), + "first open error": { + ip: net.IP{127, 0, 0, 1}, + firstOpenErr: fmt.Errorf("error"), + err: fmt.Errorf("error"), }, "read error": { readErr: fmt.Errorf("error"), err: fmt.Errorf("error"), }, + "first close error": { + firstCloseErr: fmt.Errorf("error"), + err: fmt.Errorf("error"), + }, + "second open error": { + ip: net.IP{127, 0, 0, 1}, + secondOpenErr: fmt.Errorf("error"), + err: fmt.Errorf("error"), + }, "write error": { ip: net.IP{127, 0, 0, 1}, writtenData: "nameserver 127.0.0.1\n", writeErr: fmt.Errorf("error"), err: fmt.Errorf("error"), }, + "second close error": { + ip: net.IP{127, 0, 0, 1}, + writtenData: "nameserver 127.0.0.1\n", + secondCloseErr: fmt.Errorf("error"), + err: fmt.Errorf("error"), + }, "lines without nameserver": { ip: net.IP{127, 0, 0, 1}, data: []byte("abc\ndef\n"), @@ -70,9 +88,20 @@ func Test_UseDNSSystemWide(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) - file := mock_os.NewMockFile(mockCtrl) - if tc.openErr == nil { - firstReadCall := file.EXPECT(). + type fileCall struct { + path string + flag int + perm os.FileMode + file os.File + err error + } + + var fileCalls []fileCall + + readOnlyFile := mock_os.NewMockFile(mockCtrl) + + if tc.firstOpenErr == nil { + firstReadCall := readOnlyFile.EXPECT(). Read(gomock.AssignableToTypeOf([]byte{})). DoAndReturn(func(b []byte) (int, error) { copy(b, tc.data) @@ -82,27 +111,55 @@ func Test_UseDNSSystemWide(t *testing.T) { if readErr == nil { readErr = io.EOF } - finalReadCall := file.EXPECT(). + finalReadCall := readOnlyFile.EXPECT(). Read(gomock.AssignableToTypeOf([]byte{})). Return(0, readErr).After(firstReadCall) - if tc.readErr == nil { - writeCall := file.EXPECT().WriteString(tc.writtenData). - Return(0, tc.writeErr).After(finalReadCall) - file.EXPECT().Close().Return(tc.closeErr).After(writeCall) - } else { - file.EXPECT().Close().Return(tc.closeErr).After(finalReadCall) - } + readOnlyFile.EXPECT().Close(). + Return(tc.firstCloseErr). + After(finalReadCall) } + fileCalls = append(fileCalls, fileCall{ + path: string(constants.ResolvConf), + flag: os.O_RDONLY, + perm: 0, + file: readOnlyFile, + err: tc.firstOpenErr, + }) // always return readOnlyFile + + if tc.firstOpenErr == nil && tc.readErr == nil && tc.firstCloseErr == nil { + writeOnlyFile := mock_os.NewMockFile(mockCtrl) + if tc.secondOpenErr == nil { + writeCall := writeOnlyFile.EXPECT(). + WriteString(tc.writtenData). + Return(0, tc.writeErr) + writeOnlyFile.EXPECT(). + Close(). + Return(tc.secondCloseErr). + After(writeCall) + } + fileCalls = append(fileCalls, fileCall{ + path: string(constants.ResolvConf), + flag: os.O_WRONLY | os.O_TRUNC, + perm: os.FileMode(0644), + file: writeOnlyFile, + err: tc.secondOpenErr, + }) + } + + fileCallIndex := 0 openFile := func(name string, flag int, perm os.FileMode) (os.File, error) { - assert.Equal(t, string(constants.ResolvConf), name) - assert.Equal(t, os.O_RDWR|os.O_TRUNC, flag) - assert.Equal(t, os.FileMode(0644), perm) - return file, tc.openErr + fileCall := fileCalls[fileCallIndex] + fileCallIndex++ + assert.Equal(t, fileCall.path, name) + assert.Equal(t, fileCall.flag, flag) + assert.Equal(t, fileCall.perm, perm) + return fileCall.file, fileCall.err } logger := mock_logging.NewMockLogger(mockCtrl) logger.EXPECT().Info("using DNS address %s system wide", tc.ip.String()) + c := &configurator{ openFile: openFile, logger: logger,