diff --git a/.golangci.yml b/.golangci.yml index b612403c..92d79a6b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -35,6 +35,10 @@ issues: path: "openvpnconf.go" linters: - ifshort + - linters: + - lll + source: "^//go:generate " + linters: enable: # - cyclop diff --git a/Dockerfile b/Dockerfile index a357680c..5b06fb4c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -165,6 +165,11 @@ ENV VPNSP=pia \ # Public IP PUBLICIP_FILE="/tmp/gluetun/ip" \ PUBLICIP_PERIOD=12h \ + # Pprof + PPROF_ENABLED=no \ + PPROF_BLOCK_PROFILE_RATE=0 \ + PPROF_MUTEX_PROFILE_RATE=0 \ + PPROF_HTTP_SERVER_ADDRESS=":6060" \ # Extras VERSION_INFORMATION=on \ TZ= \ diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index b5734d9e..78a29e42 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -30,6 +30,7 @@ import ( "github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/portforward" + "github.com/qdm12/gluetun/internal/pprof" "github.com/qdm12/gluetun/internal/publicip" "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/server" @@ -190,6 +191,12 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, logger.PatchLevel(*allSettings.Log.Level) + allSettings.Pprof.HTTPServer.Logger = logger + pprofServer, err := pprof.New(allSettings.Pprof) + if err != nil { + return fmt.Errorf("cannot create Pprof server: %w", err) + } + puid, pgid := int(*allSettings.System.PUID), int(*allSettings.System.PGID) const clientTimeout = 15 * time.Second @@ -334,6 +341,12 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, tickersGroupHandler := goshutdown.NewGroupHandler("tickers", defaultGroupOptions...) otherGroupHandler := goshutdown.NewGroupHandler("other", defaultGroupOptions...) + pprofReady := make(chan struct{}) + pprofHandler, pprofCtx, pprofDone := goshutdown.NewGoRoutineHandler("pprof server") + go pprofServer.Run(pprofCtx, pprofReady, pprofDone) + otherGroupHandler.Add(pprofHandler) + <-pprofReady + portForwardLogger := logger.NewChild(logging.Settings{Prefix: "port forwarding: "}) portForwardLooper := portforward.NewLoop(allSettings.VPN.Provider.PortForwarding, httpClient, firewallConf, portForwardLogger) diff --git a/internal/configuration/settings/helpers/merge.go b/internal/configuration/settings/helpers/merge.go index fd1f96b6..3a109d01 100644 --- a/internal/configuration/settings/helpers/merge.go +++ b/internal/configuration/settings/helpers/merge.go @@ -2,6 +2,7 @@ package helpers import ( "net" + "net/http" "time" "github.com/qdm12/golibs/logging" @@ -26,6 +27,13 @@ func MergeWithString(existing, other string) (result string) { return other } +func MergeWithInt(existing, other int) (result int) { + if existing != 0 { + return existing + } + return other +} + func MergeWithStringPtr(existing, other *string) (result *string) { if existing != nil { return existing @@ -37,7 +45,7 @@ func MergeWithStringPtr(existing, other *string) (result *string) { return result } -func MergeWithInt(existing, other *int) (result *int) { +func MergeWithIntPtr(existing, other *int) (result *int) { if existing != nil { return existing } else if other == nil { @@ -99,6 +107,13 @@ func MergeWithLogLevel(existing, other *logging.Level) (result *logging.Level) { return result } +func MergeWithHTTPHandler(existing, other http.Handler) (result http.Handler) { + if existing != nil { + return existing + } + return other +} + func MergeStringSlices(a, b []string) (result []string) { if a == nil && b == nil { return nil diff --git a/internal/configuration/settings/helpers/override.go b/internal/configuration/settings/helpers/override.go index 29b8bbf0..beec8e26 100644 --- a/internal/configuration/settings/helpers/override.go +++ b/internal/configuration/settings/helpers/override.go @@ -2,6 +2,7 @@ package helpers import ( "net" + "net/http" "time" "github.com/qdm12/golibs/logging" @@ -24,6 +25,13 @@ func OverrideWithString(existing, other string) (result string) { return other } +func OverrideWithInt(existing, other int) (result int) { + if other == 0 { + return existing + } + return other +} + func OverrideWithStringPtr(existing, other *string) (result *string) { if other == nil { return existing @@ -33,7 +41,7 @@ func OverrideWithStringPtr(existing, other *string) (result *string) { return result } -func OverrideWithInt(existing, other *int) (result *int) { +func OverrideWithIntPtr(existing, other *int) (result *int) { if other == nil { return existing } @@ -87,6 +95,13 @@ func OverrideWithLogLevel(existing, other *logging.Level) (result *logging.Level return result } +func OverrideWithHTTPHandler(existing, other http.Handler) (result http.Handler) { + if other != nil { + return other + } + return existing +} + func OverrideWithStringSlice(existing, other []string) (result []string) { if other == nil { return existing diff --git a/internal/configuration/settings/openvpn.go b/internal/configuration/settings/openvpn.go index 6c3ad339..f2374a74 100644 --- a/internal/configuration/settings/openvpn.go +++ b/internal/configuration/settings/openvpn.go @@ -199,7 +199,7 @@ func (o *OpenVPN) mergeWith(other OpenVPN) { o.Interface = helpers.MergeWithString(o.Interface, other.Interface) o.Root = helpers.MergeWithBool(o.Root, other.Root) o.ProcUser = helpers.MergeWithString(o.ProcUser, other.ProcUser) - o.Verbosity = helpers.MergeWithInt(o.Verbosity, other.Verbosity) + o.Verbosity = helpers.MergeWithIntPtr(o.Verbosity, other.Verbosity) o.Flags = helpers.MergeStringSlices(o.Flags, other.Flags) } @@ -221,7 +221,7 @@ func (o *OpenVPN) overrideWith(other OpenVPN) { o.Interface = helpers.OverrideWithString(o.Interface, other.Interface) o.Root = helpers.OverrideWithBool(o.Root, other.Root) o.ProcUser = helpers.OverrideWithString(o.ProcUser, other.ProcUser) - o.Verbosity = helpers.OverrideWithInt(o.Verbosity, other.Verbosity) + o.Verbosity = helpers.OverrideWithIntPtr(o.Verbosity, other.Verbosity) o.Flags = helpers.OverrideWithStringSlice(o.Flags, other.Flags) } diff --git a/internal/configuration/settings/settings.go b/internal/configuration/settings/settings.go index 4970a8ca..048a0ee8 100644 --- a/internal/configuration/settings/settings.go +++ b/internal/configuration/settings/settings.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/pprof" "github.com/qdm12/gotree" ) @@ -20,6 +21,7 @@ type Settings struct { Updater Updater Version Version VPN VPN + Pprof pprof.Settings } // Validate validates all the settings and returns an error @@ -38,6 +40,7 @@ func (s *Settings) Validate(allServers models.AllServers) (err error) { "system": s.System.validate, "updater": s.Updater.Validate, "version": s.Version.validate, + // Pprof validation done in pprof constructor "VPN": func() error { return s.VPN.validate(allServers) }, @@ -67,6 +70,7 @@ func (s *Settings) copy() (copied Settings) { Updater: s.Updater.copy(), Version: s.Version.copy(), VPN: s.VPN.copy(), + Pprof: s.Pprof.Copy(), } } @@ -83,6 +87,7 @@ func (s *Settings) MergeWith(other Settings) { s.Updater.mergeWith(other.Updater) s.Version.mergeWith(other.Version) s.VPN.mergeWith(other.VPN) + s.Pprof.MergeWith(other.Pprof) } func (s *Settings) OverrideWith(other Settings, @@ -100,6 +105,7 @@ func (s *Settings) OverrideWith(other Settings, patchedSettings.Updater.overrideWith(other.Updater) patchedSettings.Version.overrideWith(other.Version) patchedSettings.VPN.overrideWith(other.VPN) + patchedSettings.Pprof.MergeWith(other.Pprof) err = patchedSettings.Validate(allServers) if err != nil { return err @@ -121,6 +127,7 @@ func (s *Settings) SetDefaults() { s.Updater.SetDefaults() s.Version.setDefaults() s.VPN.setDefaults() + s.Pprof.SetDefaults() } func (s Settings) String() string { @@ -142,6 +149,7 @@ func (s Settings) toLinesNode() (node *gotree.Node) { node.AppendNode(s.PublicIP.toLinesNode()) node.AppendNode(s.Updater.toLinesNode()) node.AppendNode(s.Version.toLinesNode()) + node.AppendNode(s.Pprof.ToLinesNode()) return node } diff --git a/internal/configuration/sources/env/helpers.go b/internal/configuration/sources/env/helpers.go index 72502709..cdc32c60 100644 --- a/internal/configuration/sources/env/helpers.go +++ b/internal/configuration/sources/env/helpers.go @@ -21,6 +21,14 @@ func envToCSV(envKey string) (values []string) { return lowerAndSplit(csv) } +func envToInt(envKey string) (n int, err error) { + s := os.Getenv(envKey) + if s == "" { + return 0, nil + } + return strconv.Atoi(s) +} + func envToStringPtr(envKey string) (stringPtr *string) { s := os.Getenv(envKey) if s == "" { diff --git a/internal/configuration/sources/env/pprof.go b/internal/configuration/sources/env/pprof.go new file mode 100644 index 00000000..5ab659c9 --- /dev/null +++ b/internal/configuration/sources/env/pprof.go @@ -0,0 +1,29 @@ +package env + +import ( + "fmt" + "os" + + "github.com/qdm12/gluetun/internal/pprof" +) + +func readPprof() (settings pprof.Settings, err error) { + settings.Enabled, err = envToBoolPtr("PPROF_ENABLED") + if err != nil { + return settings, fmt.Errorf("environment variable PPROF_ENABLED: %w", err) + } + + settings.BlockProfileRate, err = envToInt("PPROF_BLOCK_PROFILE_RATE") + if err != nil { + return settings, fmt.Errorf("environment variable PPROF_BLOCK_PROFILE_RATE: %w", err) + } + + settings.MutexProfileRate, err = envToInt("PPROF_MUTEX_PROFILE_RATE") + if err != nil { + return settings, fmt.Errorf("environment variable PPROF_MUTEX_PROFILE_RATE: %w", err) + } + + settings.HTTPServer.Address = os.Getenv("PPROF_HTTP_SERVER_ADDRESS") + + return settings, nil +} diff --git a/internal/configuration/sources/env/reader.go b/internal/configuration/sources/env/reader.go index ccab5223..87f686ca 100644 --- a/internal/configuration/sources/env/reader.go +++ b/internal/configuration/sources/env/reader.go @@ -82,6 +82,11 @@ func (r *Reader) Read() (settings settings.Settings, err error) { return settings, err } + settings.Pprof, err = readPprof() + if err != nil { + return settings, err + } + return settings, nil } diff --git a/internal/httpserver/address.go b/internal/httpserver/address.go new file mode 100644 index 00000000..3f1c2721 --- /dev/null +++ b/internal/httpserver/address.go @@ -0,0 +1,7 @@ +package httpserver + +// GetAddress obtains the address the HTTP server is listening on. +func (s *Server) GetAddress() (address string) { + <-s.addressSet + return s.address +} diff --git a/internal/httpserver/helpers_test.go b/internal/httpserver/helpers_test.go new file mode 100644 index 00000000..304b02e3 --- /dev/null +++ b/internal/httpserver/helpers_test.go @@ -0,0 +1,43 @@ +package httpserver + +import ( + "regexp" + "time" + + gomock "github.com/golang/mock/gomock" +) + +func stringPtr(s string) *string { return &s } +func durationPtr(d time.Duration) *time.Duration { return &d } + +var _ Logger = (*testLogger)(nil) + +type testLogger struct{} + +func (t *testLogger) Info(msg string) {} +func (t *testLogger) Warn(msg string) {} +func (t *testLogger) Error(msg string) {} + +var _ gomock.Matcher = (*regexMatcher)(nil) + +type regexMatcher struct { + regexp *regexp.Regexp +} + +func (r *regexMatcher) Matches(x interface{}) bool { + s, ok := x.(string) + if !ok { + return false + } + return r.regexp.MatchString(s) +} + +func (r *regexMatcher) String() string { + return "regular expression " + r.regexp.String() +} + +func newRegexMatcher(regex string) *regexMatcher { + return ®exMatcher{ + regexp: regexp.MustCompile(regex), + } +} diff --git a/internal/httpserver/logger.go b/internal/httpserver/logger.go new file mode 100644 index 00000000..982d7e32 --- /dev/null +++ b/internal/httpserver/logger.go @@ -0,0 +1,9 @@ +package httpserver + +// Logger is the logger interface accepted by the +// HTTP server. +type Logger interface { + Info(msg string) + Warn(msg string) + Error(msg string) +} diff --git a/internal/httpserver/logger_mock_test.go b/internal/httpserver/logger_mock_test.go new file mode 100644 index 00000000..6099c87c --- /dev/null +++ b/internal/httpserver/logger_mock_test.go @@ -0,0 +1,70 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/httpserver (interfaces: Logger) + +// Package httpserver is a generated GoMock package. +package httpserver + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockLogger is a mock of Logger interface. +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger. +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance. +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Error mocks base method. +func (m *MockLogger) Error(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Error", arg0) +} + +// Error indicates an expected call of Error. +func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0) +} + +// Info mocks base method. +func (m *MockLogger) Info(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Info", arg0) +} + +// Info indicates an expected call of Info. +func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0) +} + +// Warn mocks base method. +func (m *MockLogger) Warn(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Warn", arg0) +} + +// Warn indicates an expected call of Warn. +func (mr *MockLoggerMockRecorder) Warn(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warn", reflect.TypeOf((*MockLogger)(nil).Warn), arg0) +} diff --git a/internal/httpserver/run.go b/internal/httpserver/run.go new file mode 100644 index 00000000..04284033 --- /dev/null +++ b/internal/httpserver/run.go @@ -0,0 +1,66 @@ +package httpserver + +import ( + "context" + "errors" + "net" + "net/http" +) + +// Run runs the HTTP server until ctx is canceled. +// The done channel has an error written to when the HTTP server +// is terminated, and can be nil or not nil. +func (s *Server) Run(ctx context.Context, ready chan<- struct{}, done chan<- struct{}) { + server := http.Server{Addr: s.address, Handler: s.handler} + + crashed := make(chan struct{}) + shutdownDone := make(chan struct{}) + go func() { + defer close(shutdownDone) + select { + case <-ctx.Done(): + case <-crashed: + return + } + + s.logger.Warn(s.name + " http server shutting down: " + ctx.Err().Error()) + shutdownCtx, cancel := context.WithTimeout( + context.Background(), s.shutdownTimeout) + defer cancel() + if err := server.Shutdown(shutdownCtx); err != nil { + s.logger.Error(s.name + " http server failed shutting down within " + + s.shutdownTimeout.String()) + } + }() + + listener, err := net.Listen("tcp", s.address) + if err != nil { + close(s.addressSet) + close(crashed) // stop shutdown goroutine + <-shutdownDone + s.logger.Error(err.Error()) + close(done) + return + } + + s.address = listener.Addr().String() + close(s.addressSet) + + // note: no further write so no need to mutex + s.logger.Info(s.name + " http server listening on " + s.address) + close(ready) + + err = server.Serve(listener) + + if err != nil && !errors.Is(ctx.Err(), context.Canceled) { + // server crashed + close(crashed) // stop shutdown goroutine + } else { + err = nil + } + <-shutdownDone + if err != nil { + s.logger.Error(err.Error()) + } + close(done) +} diff --git a/internal/httpserver/run_test.go b/internal/httpserver/run_test.go new file mode 100644 index 00000000..f6a4055b --- /dev/null +++ b/internal/httpserver/run_test.go @@ -0,0 +1,75 @@ +package httpserver + +import ( + "context" + "regexp" + "testing" + "time" + + gomock "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func Test_Server_Run_success(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + logger := NewMockLogger(ctrl) + logger.EXPECT().Info(newRegexMatcher("^test http server listening on 127.0.0.1:[1-9][0-9]{0,4}$")) + logger.EXPECT().Warn("test http server shutting down: context canceled") + const shutdownTimeout = 10 * time.Second + + server := &Server{ + name: "test", + address: "127.0.0.1:0", + addressSet: make(chan struct{}), + logger: logger, + shutdownTimeout: shutdownTimeout, + } + + ctx, cancel := context.WithCancel(context.Background()) + ready := make(chan struct{}) + done := make(chan struct{}) + + go server.Run(ctx, ready, done) + + addressRegex := regexp.MustCompile(`^127.0.0.1:[1-9][0-9]{0,4}$`) + address := server.GetAddress() + assert.Regexp(t, addressRegex, address) + address = server.GetAddress() + assert.Regexp(t, addressRegex, address) + + <-ready + + cancel() + _, ok := <-done + assert.False(t, ok) +} + +func Test_Server_Run_failure(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + logger := NewMockLogger(ctrl) + logger.EXPECT().Error("listen tcp: address -1: invalid port") + + server := &Server{ + name: "test", + address: "127.0.0.1:-1", + addressSet: make(chan struct{}), + logger: logger, + } + + ready := make(chan struct{}) + done := make(chan struct{}) + + go server.Run(context.Background(), ready, done) + + select { + case <-ready: + t.Fatal("server should not be ready") + case _, ok := <-done: + assert.False(t, ok) + } +} diff --git a/internal/httpserver/server.go b/internal/httpserver/server.go new file mode 100644 index 00000000..0c0e7b2a --- /dev/null +++ b/internal/httpserver/server.go @@ -0,0 +1,57 @@ +// Package httpserver implements an HTTP server. +package httpserver + +import ( + "context" + "fmt" + "net/http" + "time" +) + +var _ Interface = (*Server)(nil) + +// Interface is the HTTP server composite interface. +type Interface interface { + Runner + AddressGetter +} + +// Runner is the interface for an HTTP server with a Run method. +type Runner interface { + Run(ctx context.Context, ready chan<- struct{}, done chan<- struct{}) +} + +// AddressGetter obtains the address the HTTP server is listening on. +type AddressGetter interface { + GetAddress() (address string) +} + +// Server is an HTTP server implementation, which uses +// the HTTP handler provided. +type Server struct { + name string + address string + addressSet chan struct{} + handler http.Handler + logger Logger + shutdownTimeout time.Duration +} + +// New creates a new HTTP server with the given settings. +// It returns an error if one of the settings is not valid. +func New(settings Settings) (s *Server, err error) { + settings.SetDefaults() + + if err = settings.Validate(); err != nil { + return nil, fmt.Errorf("http server settings validation failed: %w", err) + } + + return &Server{ + name: *settings.Name, + address: settings.Address, + addressSet: make(chan struct{}), + handler: settings.Handler, + logger: settings.Logger, + shutdownTimeout: *settings.ShutdownTimeout, + }, nil +} diff --git a/internal/httpserver/server_test.go b/internal/httpserver/server_test.go new file mode 100644 index 00000000..2a097806 --- /dev/null +++ b/internal/httpserver/server_test.go @@ -0,0 +1,67 @@ +package httpserver + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +//go:generate mockgen -destination=logger_mock_test.go -package $GOPACKAGE . Logger + +func Test_New(t *testing.T) { + t.Parallel() + + someHandler := http.NewServeMux() + someLogger := &testLogger{} + + testCases := map[string]struct { + settings Settings + expected *Server + errWrapped error + errMessage string + }{ + "empty settings": { + errWrapped: ErrHandlerIsNotSet, + errMessage: "http server settings validation failed: HTTP handler cannot be left unset", + }, + "filled settings": { + settings: Settings{ + Name: stringPtr("name"), + Address: ":8001", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + expected: &Server{ + name: "name", + address: ":8001", + handler: someHandler, + logger: someLogger, + shutdownTimeout: time.Second, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + server, err := New(testCase.settings) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + require.EqualError(t, err, testCase.errMessage) + } + + if server != nil { + assert.NotNil(t, server.addressSet) + server.addressSet = nil + } + + assert.Equal(t, testCase.expected, server) + }) + } +} diff --git a/internal/httpserver/settings.go b/internal/httpserver/settings.go new file mode 100644 index 00000000..f84ad439 --- /dev/null +++ b/internal/httpserver/settings.go @@ -0,0 +1,111 @@ +package httpserver + +import ( + "errors" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/qdm12/gluetun/internal/configuration/settings/helpers" + "github.com/qdm12/gotree" + "github.com/qdm12/govalid/address" +) + +type Settings struct { + // Name is the server name to use in logs. + // It defaults to the empty string. + Name *string + // Address is the server listening address. + // It defaults to :8000. + Address string + // Handler is the HTTP Handler to use. + // It must be set and cannot be left to nil. + Handler http.Handler + // Logger is the logger to use. + // It must be set and cannot be left to nil. + Logger Logger + // ShutdownTimeout is the shutdown timeout duration + // of the HTTP server. It defaults to 3 seconds. + ShutdownTimeout *time.Duration +} + +func (s *Settings) SetDefaults() { + s.Name = helpers.DefaultStringPtr(s.Name, "") + s.Address = helpers.DefaultString(s.Address, ":8000") + const defaultShutdownTimeout = 3 * time.Second + s.ShutdownTimeout = helpers.DefaultDuration(s.ShutdownTimeout, defaultShutdownTimeout) +} + +func (s Settings) Copy() Settings { + return Settings{ + Name: helpers.CopyStringPtr(s.Name), + Address: s.Address, + Handler: s.Handler, + Logger: s.Logger, + ShutdownTimeout: helpers.CopyDurationPtr(s.ShutdownTimeout), + } +} + +func (s *Settings) MergeWith(other Settings) { + s.Name = helpers.MergeWithStringPtr(s.Name, other.Name) + s.Address = helpers.MergeWithString(s.Address, other.Address) + s.Handler = helpers.MergeWithHTTPHandler(s.Handler, other.Handler) + if s.Logger == nil { + s.Logger = other.Logger + } + s.ShutdownTimeout = helpers.MergeWithDuration(s.ShutdownTimeout, other.ShutdownTimeout) +} + +func (s *Settings) OverrideWith(other Settings) { + s.Name = helpers.OverrideWithStringPtr(s.Name, other.Name) + s.Address = helpers.OverrideWithString(s.Address, other.Address) + s.Handler = helpers.OverrideWithHTTPHandler(s.Handler, other.Handler) + if other.Logger != nil { + s.Logger = other.Logger + } + s.ShutdownTimeout = helpers.OverrideWithDuration(s.ShutdownTimeout, other.ShutdownTimeout) +} + +var ( + ErrHandlerIsNotSet = errors.New("HTTP handler cannot be left unset") + ErrLoggerIsNotSet = errors.New("logger cannot be left unset") + ErrShutdownTimeoutTooSmall = errors.New("shutdown timeout is too small") +) + +func (s Settings) Validate() (err error) { + uid := os.Getuid() + _, err = address.Validate(s.Address, address.OptionListening(uid)) + if err != nil { + return err + } + + if s.Handler == nil { + return ErrHandlerIsNotSet + } + + if s.Logger == nil { + return ErrLoggerIsNotSet + } + + const minShutdownTimeout = 5 * time.Millisecond + if *s.ShutdownTimeout < minShutdownTimeout { + return fmt.Errorf("%w: %s must be at least %s", + ErrShutdownTimeoutTooSmall, + *s.ShutdownTimeout, minShutdownTimeout) + } + + return nil +} + +func (s Settings) ToLinesNode() (node *gotree.Node) { + node = gotree.New("%s HTTP server settings:", strings.Title(*s.Name)) + node.Appendf("Listening address: %s", s.Address) + node.Appendf("Shutdown timeout: %s", *s.ShutdownTimeout) + return node +} + +func (s Settings) String() string { + return s.ToLinesNode().String() +} diff --git a/internal/httpserver/settings_test.go b/internal/httpserver/settings_test.go new file mode 100644 index 00000000..2fae6a14 --- /dev/null +++ b/internal/httpserver/settings_test.go @@ -0,0 +1,330 @@ +package httpserver + +import ( + "net/http" + "testing" + "time" + + "github.com/qdm12/govalid/address" + "github.com/stretchr/testify/assert" +) + +func Test_Settings_SetDefaults(t *testing.T) { + t.Parallel() + + const defaultTimeout = 3 * time.Second + + testCases := map[string]struct { + settings Settings + expected Settings + }{ + "empty settings": { + settings: Settings{}, + expected: Settings{ + Name: stringPtr(""), + Address: ":8000", + ShutdownTimeout: durationPtr(defaultTimeout), + }, + }, + "filled settings": { + settings: Settings{ + Name: stringPtr("name"), + Address: ":8001", + ShutdownTimeout: durationPtr(time.Second), + }, + expected: Settings{ + Name: stringPtr("name"), + Address: ":8001", + ShutdownTimeout: durationPtr(time.Second), + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.settings.SetDefaults() + + assert.Equal(t, testCase.expected, testCase.settings) + }) + } +} + +func Test_Settings_Copy(t *testing.T) { + t.Parallel() + + someHandler := http.NewServeMux() + someLogger := &testLogger{} + + testCases := map[string]struct { + settings Settings + expected Settings + }{ + "empty settings": {}, + "filled settings": { + settings: Settings{ + Name: stringPtr("name"), + Address: ":8001", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + expected: Settings{ + Name: stringPtr("name"), + Address: ":8001", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + copied := testCase.settings.Copy() + + assert.Equal(t, testCase.expected, copied) + }) + } +} + +func Test_Settings_MergeWith(t *testing.T) { + t.Parallel() + + someHandler := http.NewServeMux() + someLogger := &testLogger{} + + testCases := map[string]struct { + settings Settings + other Settings + expected Settings + }{ + "merge empty with empty": {}, + "merge empty with filled": { + other: Settings{ + Name: stringPtr("name"), + Address: ":8001", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + expected: Settings{ + Name: stringPtr("name"), + Address: ":8001", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + }, + "merge filled with empty": { + settings: Settings{ + Name: stringPtr("name"), + Address: ":8001", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + expected: Settings{ + Name: stringPtr("name"), + Address: ":8001", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.settings.MergeWith(testCase.other) + + assert.Equal(t, testCase.expected, testCase.settings) + }) + } +} + +func Test_Settings_OverrideWith(t *testing.T) { + t.Parallel() + + someHandler := http.NewServeMux() + someLogger := &testLogger{} + + testCases := map[string]struct { + settings Settings + other Settings + expected Settings + }{ + "override empty with empty": {}, + "override empty with filled": { + other: Settings{ + Name: stringPtr("name"), + Address: ":8001", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + expected: Settings{ + Name: stringPtr("name"), + Address: ":8001", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + }, + "override filled with empty": { + settings: Settings{ + Name: stringPtr("name"), + Address: ":8001", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + expected: Settings{ + Name: stringPtr("name"), + Address: ":8001", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + }, + "override filled with filled": { + settings: Settings{ + Name: stringPtr("name"), + Address: ":8001", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + other: Settings{ + Name: stringPtr("name2"), + Address: ":8002", + ShutdownTimeout: durationPtr(time.Hour), + }, + expected: Settings{ + Name: stringPtr("name2"), + Address: ":8002", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Hour), + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.settings.OverrideWith(testCase.other) + + assert.Equal(t, testCase.expected, testCase.settings) + }) + } +} + +func Test_Settings_Validate(t *testing.T) { + t.Parallel() + + someHandler := http.NewServeMux() + someLogger := &testLogger{} + + testCases := map[string]struct { + settings Settings + errWrapped error + errMessage string + }{ + "bad address": { + settings: Settings{ + Address: "noport", + }, + errWrapped: address.ErrValueNotValid, + errMessage: "value is not valid: address noport: missing port in address", + }, + "nil handler": { + settings: Settings{ + Address: ":8000", + }, + errWrapped: ErrHandlerIsNotSet, + errMessage: ErrHandlerIsNotSet.Error(), + }, + "nil logger": { + settings: Settings{ + Address: ":8000", + Handler: someHandler, + }, + errWrapped: ErrLoggerIsNotSet, + errMessage: ErrLoggerIsNotSet.Error(), + }, + "shutdown timeout too small": { + settings: Settings{ + Address: ":8000", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Millisecond), + }, + errWrapped: ErrShutdownTimeoutTooSmall, + errMessage: "shutdown timeout is too small: 1ms must be at least 5ms", + }, + "valid settings": { + settings: Settings{ + Address: ":8000", + Handler: someHandler, + Logger: someLogger, + ShutdownTimeout: durationPtr(time.Second), + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := testCase.settings.Validate() + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Test_Settings_String(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + settings Settings + s string + }{ + "all values": { + settings: Settings{ + Name: stringPtr("name"), + Address: ":8000", + ShutdownTimeout: durationPtr(time.Second), + }, + s: `Name HTTP server settings: +├── Listening address: :8000 +└── Shutdown timeout: 1s`, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + s := testCase.settings.String() + + assert.Equal(t, testCase.s, s) + }) + } +} diff --git a/internal/pprof/helpers_test.go b/internal/pprof/helpers_test.go new file mode 100644 index 00000000..168a5be2 --- /dev/null +++ b/internal/pprof/helpers_test.go @@ -0,0 +1,36 @@ +package pprof + +import ( + "regexp" + "time" + + gomock "github.com/golang/mock/gomock" +) + +func boolPtr(b bool) *bool { return &b } +func stringPtr(s string) *string { return &s } +func durationPtr(d time.Duration) *time.Duration { return &d } + +var _ gomock.Matcher = (*regexMatcher)(nil) + +type regexMatcher struct { + regexp *regexp.Regexp +} + +func (r *regexMatcher) Matches(x interface{}) bool { + s, ok := x.(string) + if !ok { + return false + } + return r.regexp.MatchString(s) +} + +func (r *regexMatcher) String() string { + return "regular expression " + r.regexp.String() +} + +func newRegexMatcher(regex string) *regexMatcher { + return ®exMatcher{ + regexp: regexp.MustCompile(regex), + } +} diff --git a/internal/pprof/logger_mock_test.go b/internal/pprof/logger_mock_test.go new file mode 100644 index 00000000..62e8d4a4 --- /dev/null +++ b/internal/pprof/logger_mock_test.go @@ -0,0 +1,70 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/httpserver (interfaces: Logger) + +// Package pprof is a generated GoMock package. +package pprof + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockLogger is a mock of Logger interface. +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger. +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance. +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Error mocks base method. +func (m *MockLogger) Error(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Error", arg0) +} + +// Error indicates an expected call of Error. +func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0) +} + +// Info mocks base method. +func (m *MockLogger) Info(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Info", arg0) +} + +// Info indicates an expected call of Info. +func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0) +} + +// Warn mocks base method. +func (m *MockLogger) Warn(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Warn", arg0) +} + +// Warn indicates an expected call of Warn. +func (mr *MockLoggerMockRecorder) Warn(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warn", reflect.TypeOf((*MockLogger)(nil).Warn), arg0) +} diff --git a/internal/pprof/server.go b/internal/pprof/server.go new file mode 100644 index 00000000..ad246eca --- /dev/null +++ b/internal/pprof/server.go @@ -0,0 +1,40 @@ +package pprof + +import ( + "fmt" + "net/http" + "net/http/pprof" + "runtime" + + "github.com/qdm12/gluetun/internal/httpserver" +) + +// New creates a new Pprof server and configure profiling +// with the settings given. It returns an error +// if one of the settings is not valid. +func New(settings Settings) (server *httpserver.Server, err error) { + runtime.SetBlockProfileRate(settings.BlockProfileRate) + runtime.SetMutexProfileFraction(settings.MutexProfileRate) + + handler := http.NewServeMux() + handler.HandleFunc("/debug/pprof/", pprof.Index) + handler.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + handler.HandleFunc("/debug/pprof/profile", pprof.Profile) + handler.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + handler.HandleFunc("/debug/pprof/trace", pprof.Trace) + handler.Handle("/debug/pprof/block", pprof.Handler("block")) + handler.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine")) + handler.Handle("/debug/pprof/heap", pprof.Handler("heap")) + handler.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate")) + + httpServerName := "pprof" + settings.HTTPServer.Name = &httpServerName + settings.HTTPServer.Handler = handler + + settings.SetDefaults() + if err = settings.Validate(); err != nil { + return nil, fmt.Errorf("pprof settings failed validation: %w", err) + } + + return httpserver.New(settings.HTTPServer) +} diff --git a/internal/pprof/server_test.go b/internal/pprof/server_test.go new file mode 100644 index 00000000..e4d91123 --- /dev/null +++ b/internal/pprof/server_test.go @@ -0,0 +1,124 @@ +package pprof + +import ( + "context" + "io/ioutil" + "net/http" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/qdm12/gluetun/internal/httpserver" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +//go:generate mockgen -destination=logger_mock_test.go -package $GOPACKAGE github.com/qdm12/gluetun/internal/httpserver Logger + +func Test_Server(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + const address = "127.0.0.1:0" + logger := NewMockLogger(ctrl) + + logger.EXPECT().Info(newRegexMatcher("^pprof http server listening on 127.0.0.1:[1-9][0-9]{0,4}$")) + logger.EXPECT().Warn("pprof http server shutting down: context canceled") + + const httpServerShutdownTimeout = 10 * time.Second // 10s in case test worker is slow + settings := Settings{ + HTTPServer: httpserver.Settings{ + Address: address, + Logger: logger, + ShutdownTimeout: durationPtr(httpServerShutdownTimeout), + }, + } + + server, err := New(settings) + require.NoError(t, err) + require.NotNil(t, server) + + ctx, cancel := context.WithCancel(context.Background()) + ready := make(chan struct{}) + done := make(chan struct{}) + + go server.Run(ctx, ready, done) + + select { + case <-ready: + case err := <-done: + t.Fatalf("server crashed before being ready: %s", err) + } + + serverAddress := server.GetAddress() + + const clientTimeout = 2 * time.Second + httpClient := &http.Client{Timeout: clientTimeout} + + pathsToCheck := []string{ + "debug/pprof/", + "debug/pprof/cmdline", + "debug/pprof/profile?seconds=1", + "debug/pprof/symbol", + "debug/pprof/trace?seconds=1", + "debug/pprof/block", + "debug/pprof/goroutine", + "debug/pprof/heap", + "debug/pprof/threadcreate", + } + + type httpResult struct { + url string + response *http.Response + err error + } + results := make(chan httpResult) + + for _, pathToCheck := range pathsToCheck { + url := "http://" + serverAddress + "/" + pathToCheck + + request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + go func(client *http.Client, request *http.Request, results chan<- httpResult) { + response, err := client.Do(request) //nolint:bodyclose + results <- httpResult{ + url: request.URL.String(), + response: response, + err: err, + } + }(httpClient, request, results) + } + + for range pathsToCheck { + httpResult := <-results + + require.NoErrorf(t, httpResult.err, "unexpected error for URL %s: %s", httpResult.url, httpResult.err) + assert.Equalf(t, http.StatusOK, httpResult.response.StatusCode, + "unexpected status code for URL %s: %s", httpResult.url, http.StatusText(httpResult.response.StatusCode)) + + b, err := ioutil.ReadAll(httpResult.response.Body) + require.NoErrorf(t, err, "unexpected error for URL %s: %s", httpResult.url, err) + assert.NotEmptyf(t, b, "response body is empty for URL %s", httpResult.url) + + err = httpResult.response.Body.Close() + assert.NoErrorf(t, err, "unexpected error for URL %s: %s", httpResult.url, err) + } + + cancel() + <-done +} + +func Test_Server_BadSettings(t *testing.T) { + t.Parallel() + + settings := Settings{ + BlockProfileRate: -1, + } + + server, err := New(settings) + assert.Nil(t, server) + assert.ErrorIs(t, err, ErrBlockProfileRateNegative) + const expectedErrMessage = "pprof settings failed validation: block profile rate cannot be negative" + assert.EqualError(t, err, expectedErrMessage) +} diff --git a/internal/pprof/settings.go b/internal/pprof/settings.go new file mode 100644 index 00000000..5d0ece0e --- /dev/null +++ b/internal/pprof/settings.go @@ -0,0 +1,96 @@ +package pprof + +import ( + "errors" + + "github.com/qdm12/gluetun/internal/configuration/settings/helpers" + "github.com/qdm12/gluetun/internal/httpserver" + "github.com/qdm12/gotree" +) + +// Settings are the settings for the Pprof service. +type Settings struct { + // Enabled can be false or true. + // It defaults to false. + Enabled *bool + // See runtime.SetBlockProfileRate + // Set to 0 to disable profiling. + BlockProfileRate int + // See runtime.SetMutexProfileFraction + // Set to 0 to disable profiling. + MutexProfileRate int + // HTTPServer contains settings to configure + // the HTTP server serving pprof data. + HTTPServer httpserver.Settings +} + +func (s *Settings) SetDefaults() { + s.Enabled = helpers.DefaultBool(s.Enabled, false) + s.HTTPServer.Name = helpers.DefaultStringPtr(s.HTTPServer.Name, "pprof") + s.HTTPServer.Address = helpers.DefaultString(s.HTTPServer.Address, "localhost:6060") + s.HTTPServer.SetDefaults() +} + +func (s Settings) Copy() (copied Settings) { + return Settings{ + Enabled: helpers.CopyBoolPtr(s.Enabled), + BlockProfileRate: s.BlockProfileRate, + MutexProfileRate: s.MutexProfileRate, + HTTPServer: s.HTTPServer.Copy(), + } +} + +func (s *Settings) MergeWith(other Settings) { + s.Enabled = helpers.MergeWithBool(s.Enabled, other.Enabled) + s.BlockProfileRate = helpers.MergeWithInt(s.BlockProfileRate, other.BlockProfileRate) + s.MutexProfileRate = helpers.MergeWithInt(s.MutexProfileRate, other.MutexProfileRate) + s.HTTPServer.MergeWith(other.HTTPServer) +} + +func (s *Settings) OverrideWith(other Settings) { + s.Enabled = helpers.OverrideWithBool(s.Enabled, other.Enabled) + s.BlockProfileRate = helpers.OverrideWithInt(s.BlockProfileRate, other.BlockProfileRate) + s.MutexProfileRate = helpers.OverrideWithInt(s.MutexProfileRate, other.MutexProfileRate) + s.HTTPServer.OverrideWith(other.HTTPServer) +} + +var ( + ErrBlockProfileRateNegative = errors.New("block profile rate cannot be negative") + ErrMutexProfileRateNegative = errors.New("mutex profile rate cannot be negative") +) + +func (s Settings) Validate() (err error) { + if s.BlockProfileRate < 0 { + return ErrBlockProfileRateNegative + } + + if s.MutexProfileRate < 0 { + return ErrMutexProfileRateNegative + } + + return s.HTTPServer.Validate() +} + +func (s Settings) ToLinesNode() (node *gotree.Node) { + if !*s.Enabled { + return nil + } + + node = gotree.New("Pprof settings:") + + if s.BlockProfileRate > 0 { + node.Appendf("Block profile rate: %d", s.BlockProfileRate) + } + + if s.MutexProfileRate > 0 { + node.Appendf("Mutex profile rate: %d", s.MutexProfileRate) + } + + node.AppendNode(s.HTTPServer.ToLinesNode()) + + return node +} + +func (s Settings) String() string { + return s.ToLinesNode().String() +} diff --git a/internal/pprof/settings_test.go b/internal/pprof/settings_test.go new file mode 100644 index 00000000..ac5b953e --- /dev/null +++ b/internal/pprof/settings_test.go @@ -0,0 +1,352 @@ +package pprof + +import ( + "net/http" + "testing" + "time" + + "github.com/qdm12/gluetun/internal/httpserver" + "github.com/qdm12/govalid/address" + "github.com/stretchr/testify/assert" +) + +func Test_Settings_SetDefaults(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + initial Settings + expected Settings + }{ + "empty settings": { + expected: Settings{ + Enabled: boolPtr(false), + HTTPServer: httpserver.Settings{ + Name: stringPtr("pprof"), + Address: "localhost:6060", + ShutdownTimeout: durationPtr(3 * time.Second), + }, + }, + }, + "non empty settings": { + initial: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Name: stringPtr("custom"), + Address: ":6061", + ShutdownTimeout: durationPtr(time.Second), + }, + }, + expected: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Name: stringPtr("custom"), + Address: ":6061", + ShutdownTimeout: durationPtr(time.Second), + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.initial.SetDefaults() + + assert.Equal(t, testCase.expected, testCase.initial) + }) + } +} + +func Test_Settings_Copy(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + initial Settings + expected Settings + }{ + "empty settings": {}, + "non empty settings": { + initial: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Name: stringPtr("custom"), + Address: ":6061", + ShutdownTimeout: durationPtr(time.Second), + }, + }, + expected: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Name: stringPtr("custom"), + Address: ":6061", + ShutdownTimeout: durationPtr(time.Second), + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + copied := testCase.initial.Copy() + + assert.Equal(t, testCase.expected, copied) + }) + } +} + +func Test_Settings_MergeWith(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + settings Settings + other Settings + expected Settings + }{ + "merge empty with empty": {}, + "merge empty with filled": { + other: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Address: ":8001", + }, + }, + expected: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Address: ":8001", + }, + }, + }, + "merge filled with empty": { + settings: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Address: ":8001", + }, + }, + expected: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Address: ":8001", + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.settings.MergeWith(testCase.other) + + assert.Equal(t, testCase.expected, testCase.settings) + }) + } +} + +func Test_Settings_OverrideWith(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + settings Settings + other Settings + expected Settings + }{ + "override empty with empty": {}, + "override empty with filled": { + other: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Address: ":8001", + }, + }, + expected: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Address: ":8001", + }, + }, + }, + "override filled with empty": { + settings: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Address: ":8001", + }, + }, + expected: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Address: ":8001", + }, + }, + }, + "override filled with filled": { + settings: Settings{ + Enabled: boolPtr(false), + BlockProfileRate: 1, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Address: ":8001", + }, + }, + other: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 2, + MutexProfileRate: 3, + HTTPServer: httpserver.Settings{ + Address: ":8002", + }, + }, + expected: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 2, + MutexProfileRate: 3, + HTTPServer: httpserver.Settings{ + Address: ":8002", + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.settings.OverrideWith(testCase.other) + + assert.Equal(t, testCase.expected, testCase.settings) + }) + } +} + +func Test_Settings_Validate(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + settings Settings + errWrapped error + errMessage string + }{ + "negative block profile rate": { + settings: Settings{ + BlockProfileRate: -1, + }, + errWrapped: ErrBlockProfileRateNegative, + errMessage: ErrBlockProfileRateNegative.Error(), + }, + "negative mutex profile rate": { + settings: Settings{ + MutexProfileRate: -1, + }, + errWrapped: ErrMutexProfileRateNegative, + errMessage: ErrMutexProfileRateNegative.Error(), + }, + "http server validation error": { + settings: Settings{ + HTTPServer: httpserver.Settings{}, + }, + errWrapped: address.ErrValueNotValid, + errMessage: "value is not valid: missing port in address", + }, + "valid settings": { + settings: Settings{ + HTTPServer: httpserver.Settings{ + Address: ":8000", + Handler: http.NewServeMux(), + Logger: &MockLogger{}, + ShutdownTimeout: durationPtr(time.Second), + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := testCase.settings.Validate() + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Test_Settings_String(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + settings Settings + s string + }{ + "disabled pprof": { + settings: Settings{ + Enabled: boolPtr(false), + }, + }, + "all values": { + settings: Settings{ + Enabled: boolPtr(true), + BlockProfileRate: 2, + MutexProfileRate: 1, + HTTPServer: httpserver.Settings{ + Name: stringPtr("name"), + Address: ":8000", + ShutdownTimeout: durationPtr(time.Second), + }, + }, + s: `Pprof settings: +├── Block profile rate: 2 +├── Mutex profile rate: 1 +└── Name HTTP server settings: + ├── Listening address: :8000 + └── Shutdown timeout: 1s`, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + s := testCase.settings.String() + + assert.Equal(t, testCase.s, s) + }) + } +}