feat(server): role based authentication system (#2434)

- Parse toml configuration file, see https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/control-server.md#authentication
- Retro-compatible with existing AND documented routes, until after v3.41 release
- Log a warning if an unprotected-by-default route is accessed unprotected
- Authentication methods: none, apikey, basic
- `genkey` command to generate API keys

Co-authored-by: Joe Jose <45399349+joejose97@users.noreply.github.com>
This commit is contained in:
Quentin McGaw
2024-09-18 13:29:36 +02:00
committed by GitHub
parent 07651683f9
commit a2e76e1683
25 changed files with 920 additions and 10 deletions

View File

@@ -0,0 +1,36 @@
package auth
import (
"crypto/sha256"
"crypto/subtle"
"net/http"
)
type apiKeyMethod struct {
apiKeyDigest [32]byte
}
func newAPIKeyMethod(apiKey string) *apiKeyMethod {
return &apiKeyMethod{
apiKeyDigest: sha256.Sum256([]byte(apiKey)),
}
}
// equal returns true if another auth checker is equal.
// This is used to deduplicate checkers for a particular route.
func (a *apiKeyMethod) equal(other authorizationChecker) bool {
otherTokenMethod, ok := other.(*apiKeyMethod)
if !ok {
return false
}
return a.apiKeyDigest == otherTokenMethod.apiKeyDigest
}
func (a *apiKeyMethod) isAuthorized(_ http.Header, request *http.Request) bool {
xAPIKey := request.Header.Get("X-API-Key")
if xAPIKey == "" {
xAPIKey = request.URL.Query().Get("api_key")
}
xAPIKeyDigest := sha256.Sum256([]byte(xAPIKey))
return subtle.ConstantTimeCompare(xAPIKeyDigest[:], a.apiKeyDigest[:]) == 1
}

View File

@@ -0,0 +1,37 @@
package auth
import (
"crypto/sha256"
"crypto/subtle"
"net/http"
)
type basicAuthMethod struct {
authDigest [32]byte
}
func newBasicAuthMethod(username, password string) *basicAuthMethod {
return &basicAuthMethod{
authDigest: sha256.Sum256([]byte(username + password)),
}
}
// equal returns true if another auth checker is equal.
// This is used to deduplicate checkers for a particular route.
func (a *basicAuthMethod) equal(other authorizationChecker) bool {
otherBasicMethod, ok := other.(*basicAuthMethod)
if !ok {
return false
}
return a.authDigest == otherBasicMethod.authDigest
}
func (a *basicAuthMethod) isAuthorized(headers http.Header, request *http.Request) bool {
username, password, ok := request.BasicAuth()
if !ok {
headers.Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
return false
}
requestAuthDigest := sha256.Sum256([]byte(username + password))
return subtle.ConstantTimeCompare(a.authDigest[:], requestAuthDigest[:]) == 1
}

View File

@@ -0,0 +1,35 @@
package auth
import (
"errors"
"fmt"
"os"
"github.com/pelletier/go-toml/v2"
)
// Read reads the toml file specified by the filepath given.
// If the file does not exist, it returns empty settings and no error.
func Read(filepath string) (settings Settings, err error) {
file, err := os.Open(filepath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return Settings{}, nil
}
return settings, fmt.Errorf("opening file: %w", err)
}
decoder := toml.NewDecoder(file)
decoder.DisallowUnknownFields()
err = decoder.Decode(&settings)
if err == nil {
return settings, nil
}
strictErr := new(toml.StrictMissingError)
ok := errors.As(err, &strictErr)
if !ok {
return settings, fmt.Errorf("toml decoding file: %w", err)
}
return settings, fmt.Errorf("toml decoding file: %w:\n%s",
strictErr, strictErr.String())
}

View File

@@ -0,0 +1,80 @@
package auth
import (
"io/fs"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Read reads the toml file specified by the filepath given.
func Test_Read(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
fileContent string
settings Settings
errMessage string
}{
"empty_file": {},
"malformed_toml": {
fileContent: "this is not a toml file",
errMessage: `toml decoding file: toml: expected character =`,
},
"unknown_field": {
fileContent: `unknown = "what is this"`,
errMessage: `toml decoding file: strict mode: fields in the document are missing in the target struct:
1| unknown = "what is this"
| ~~~~~~~ missing field`,
},
"filled_settings": {
fileContent: `[[roles]]
name = "public"
auth = "none"
routes = ["GET /v1/vpn/status", "PUT /v1/vpn/status"]
[[roles]]
name = "client"
auth = "apikey"
apikey = "xyz"
routes = ["GET /v1/vpn/status"]
`,
settings: Settings{
Roles: []Role{{
Name: "public",
Auth: AuthNone,
Routes: []string{"GET /v1/vpn/status", "PUT /v1/vpn/status"},
}, {
Name: "client",
Auth: AuthAPIKey,
APIKey: "xyz",
Routes: []string{"GET /v1/vpn/status"},
}},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
filepath := tempDir + "/config.toml"
const permissions fs.FileMode = 0600
err := os.WriteFile(filepath, []byte(testCase.fileContent), permissions)
require.NoError(t, err)
settings, err := Read(filepath)
assert.Equal(t, testCase.settings, settings)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,22 @@
package auth
func andStrings(strings []string) (result string) {
return joinStrings(strings, "and")
}
func joinStrings(strings []string, lastJoin string) (result string) {
if len(strings) == 0 {
return ""
}
result = strings[0]
for i := 1; i < len(strings); i++ {
if i < len(strings)-1 {
result += ", " + strings[i]
} else {
result += " " + lastJoin + " " + strings[i]
}
}
return result
}

View File

@@ -0,0 +1,6 @@
package auth
type DebugLogger interface {
Debugf(format string, args ...any)
Warnf(format string, args ...any)
}

View File

@@ -0,0 +1,8 @@
package auth
import "net/http"
type authorizationChecker interface {
equal(other authorizationChecker) bool
isAuthorized(headers http.Header, request *http.Request) bool
}

View File

@@ -0,0 +1,47 @@
package auth
import (
"fmt"
)
type internalRole struct {
name string
checker authorizationChecker
}
func settingsToLookupMap(settings Settings) (routeToRoles map[string][]internalRole, err error) {
routeToRoles = make(map[string][]internalRole)
for _, role := range settings.Roles {
var checker authorizationChecker
switch role.Auth {
case AuthNone:
checker = newNoneMethod()
case AuthAPIKey:
checker = newAPIKeyMethod(role.APIKey)
case AuthBasic:
checker = newBasicAuthMethod(role.Username, role.Password)
default:
return nil, fmt.Errorf("%w: %s", ErrMethodNotSupported, role.Auth)
}
iRole := internalRole{
name: role.Name,
checker: checker,
}
for _, route := range role.Routes {
checkerExists := false
for _, role := range routeToRoles[route] {
if role.checker.equal(iRole.checker) {
checkerExists = true
break
}
}
if checkerExists {
// even if the role name is different, if the checker is the same, skip it.
continue
}
routeToRoles[route] = append(routeToRoles[route], iRole)
}
}
return routeToRoles, nil
}

View File

@@ -0,0 +1,60 @@
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
)
// Read reads the toml file specified by the filepath given.
func Test_settingsToLookupMap(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
settings Settings
routeToRoles map[string][]internalRole
errWrapped error
errMessage string
}{
"empty_settings": {
routeToRoles: map[string][]internalRole{},
},
"auth_method_not_supported": {
settings: Settings{
Roles: []Role{{Name: "a", Auth: "bad"}},
},
errWrapped: ErrMethodNotSupported,
errMessage: "authentication method not supported: bad",
},
"success": {
settings: Settings{
Roles: []Role{
{Name: "a", Auth: AuthNone, Routes: []string{"GET /path"}},
{Name: "b", Auth: AuthNone, Routes: []string{"GET /path", "PUT /path"}},
},
},
routeToRoles: map[string][]internalRole{
"GET /path": {
{name: "a", checker: newNoneMethod()}, // deduplicated method
},
"PUT /path": {
{name: "b", checker: newNoneMethod()},
}},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
routeToRoles, err := settingsToLookupMap(testCase.settings)
assert.Equal(t, testCase.routeToRoles, routeToRoles)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}

View File

@@ -0,0 +1,111 @@
package auth
import (
"fmt"
"net/http"
)
func New(settings Settings, debugLogger DebugLogger) (
middleware func(http.Handler) http.Handler,
err error) {
routeToRoles, err := settingsToLookupMap(settings)
if err != nil {
return nil, fmt.Errorf("converting settings to lookup maps: %w", err)
}
//nolint:goconst
return func(handler http.Handler) http.Handler {
return &authHandler{
childHandler: handler,
routeToRoles: routeToRoles,
unprotectedRoutes: map[string]struct{}{
http.MethodGet + " /openvpn/actions/restart": {},
http.MethodGet + " /unbound/actions/restart": {},
http.MethodGet + " /updater/restart": {},
http.MethodGet + " /v1/version": {},
http.MethodGet + " /v1/vpn/status": {},
http.MethodPut + " /v1/vpn/status": {},
// GET /v1/vpn/settings is protected by default
// PUT /v1/vpn/settings is protected by default
http.MethodGet + " /v1/openvpn/status": {},
http.MethodPut + " /v1/openvpn/status": {},
http.MethodGet + " /v1/openvpn/portforwarded": {},
// GET /v1/openvpn/settings is protected by default
http.MethodGet + " /v1/dns/status": {},
http.MethodPut + " /v1/dns/status": {},
http.MethodGet + " /v1/updater/status": {},
http.MethodPut + " /v1/updater/status": {},
http.MethodGet + " /v1/publicip/ip": {},
},
logger: debugLogger,
}
}, nil
}
type authHandler struct {
childHandler http.Handler
routeToRoles map[string][]internalRole
unprotectedRoutes map[string]struct{} // TODO v3.41.0 remove
logger DebugLogger
}
func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
route := request.Method + " " + request.URL.Path
roles := h.routeToRoles[route]
if len(roles) == 0 {
h.logger.Debugf("no authentication role defined for route %s", route)
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
responseHeader := make(http.Header, 0)
for _, role := range roles {
if !role.checker.isAuthorized(responseHeader, request) {
continue
}
h.warnIfUnprotectedByDefault(role, route) // TODO v3.41.0 remove
h.logger.Debugf("access to route %s authorized for role %s", route, role.name)
h.childHandler.ServeHTTP(writer, request)
return
}
// Flush out response headers if all roles failed to authenticate
for headerKey, headerValues := range responseHeader {
for _, headerValue := range headerValues {
writer.Header().Add(headerKey, headerValue)
}
}
allRoleNames := make([]string, len(roles))
for i, role := range roles {
allRoleNames[i] = role.name
}
h.logger.Debugf("access to route %s unauthorized after checking for roles %s",
route, andStrings(allRoleNames))
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
func (h *authHandler) warnIfUnprotectedByDefault(role internalRole, route string) {
// TODO v3.41.0 remove
if role.name != "public" {
// custom role name, allow none authentication to be specified
return
}
_, isNoneChecker := role.checker.(*noneMethod)
if !isNoneChecker {
// not the none authentication method
return
}
_, isUnprotectedByDefault := h.unprotectedRoutes[route]
if !isUnprotectedByDefault {
// route is not unprotected by default, so this is a user decision
return
}
h.logger.Warnf("route %s is unprotected by default, "+
"please set up authentication following the documentation at "+
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
"since this will become no longer publicly accessible after release v3.40.",
route)
}

View File

@@ -0,0 +1,124 @@
package auth
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_authHandler_ServeHTTP(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
settings Settings
makeLogger func(ctrl *gomock.Controller) *MockDebugLogger
requestMethod string
requestPath string
statusCode int
responseBody string
}{
"route_has_no_role": {
settings: Settings{
Roles: []Role{
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
},
},
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
logger := NewMockDebugLogger(ctrl)
logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b")
return logger
},
requestMethod: http.MethodGet,
requestPath: "/b",
statusCode: http.StatusUnauthorized,
responseBody: "Unauthorized\n",
},
"authorized_unprotected_by_default": {
settings: Settings{
Roles: []Role{
{Name: "public", Auth: AuthNone, Routes: []string{"GET /v1/vpn/status"}},
},
},
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
logger := NewMockDebugLogger(ctrl)
logger.EXPECT().Warnf("route %s is unprotected by default, "+
"please set up authentication following the documentation at "+
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
"since this will become no longer publicly accessible after release v3.40.",
"GET /v1/vpn/status")
logger.EXPECT().Debugf("access to route %s authorized for role %s",
"GET /v1/vpn/status", "public")
return logger
},
requestMethod: http.MethodGet,
requestPath: "/v1/vpn/status",
statusCode: http.StatusOK,
},
"authorized_none": {
settings: Settings{
Roles: []Role{
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
},
},
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
logger := NewMockDebugLogger(ctrl)
logger.EXPECT().Debugf("access to route %s authorized for role %s",
"GET /a", "role1")
return logger
},
requestMethod: http.MethodGet,
requestPath: "/a",
statusCode: http.StatusOK,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
var debugLogger DebugLogger
if testCase.makeLogger != nil {
debugLogger = testCase.makeLogger(ctrl)
}
middleware, err := New(testCase.settings, debugLogger)
require.NoError(t, err)
childHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := middleware(childHandler)
server := httptest.NewServer(handler)
t.Cleanup(server.Close)
client := server.Client()
requestURL, err := url.JoinPath(server.URL, testCase.requestPath)
require.NoError(t, err)
request, err := http.NewRequestWithContext(context.Background(),
testCase.requestMethod, requestURL, nil)
require.NoError(t, err)
response, err := client.Do(request)
require.NoError(t, err)
t.Cleanup(func() {
err = response.Body.Close()
assert.NoError(t, err)
})
assert.Equal(t, testCase.statusCode, response.StatusCode)
body, err := io.ReadAll(response.Body)
require.NoError(t, err)
assert.Equal(t, testCase.responseBody, string(body))
})
}
}

View File

@@ -0,0 +1,3 @@
package auth
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . DebugLogger

View File

@@ -0,0 +1,68 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/server/middlewares/auth (interfaces: DebugLogger)
// Package auth is a generated GoMock package.
package auth
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockDebugLogger is a mock of DebugLogger interface.
type MockDebugLogger struct {
ctrl *gomock.Controller
recorder *MockDebugLoggerMockRecorder
}
// MockDebugLoggerMockRecorder is the mock recorder for MockDebugLogger.
type MockDebugLoggerMockRecorder struct {
mock *MockDebugLogger
}
// NewMockDebugLogger creates a new mock instance.
func NewMockDebugLogger(ctrl *gomock.Controller) *MockDebugLogger {
mock := &MockDebugLogger{ctrl: ctrl}
mock.recorder = &MockDebugLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDebugLogger) EXPECT() *MockDebugLoggerMockRecorder {
return m.recorder
}
// Debugf mocks base method.
func (m *MockDebugLogger) Debugf(arg0 string, arg1 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Debugf", varargs...)
}
// Debugf indicates an expected call of Debugf.
func (mr *MockDebugLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockDebugLogger)(nil).Debugf), varargs...)
}
// Warnf mocks base method.
func (m *MockDebugLogger) Warnf(arg0 string, arg1 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Warnf", varargs...)
}
// Warnf indicates an expected call of Warnf.
func (mr *MockDebugLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockDebugLogger)(nil).Warnf), varargs...)
}

View File

@@ -0,0 +1,20 @@
package auth
import "net/http"
type noneMethod struct{}
func newNoneMethod() *noneMethod {
return &noneMethod{}
}
// equal returns true if another auth checker is equal.
// This is used to deduplicate checkers for a particular route.
func (n *noneMethod) equal(other authorizationChecker) bool {
_, ok := other.(*noneMethod)
return ok
}
func (n *noneMethod) isAuthorized(_ http.Header, _ *http.Request) bool {
return true
}

View File

@@ -0,0 +1,131 @@
package auth
import (
"errors"
"fmt"
"net/http"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/validate"
)
type Settings struct {
// Roles is a list of roles with their associated authentication
// and routes.
Roles []Role
}
func (s *Settings) SetDefaults() {
s.Roles = gosettings.DefaultSlice(s.Roles, []Role{{ // TODO v3.41.0 leave empty
Name: "public",
Auth: "none",
Routes: []string{
http.MethodGet + " /openvpn/actions/restart",
http.MethodGet + " /unbound/actions/restart",
http.MethodGet + " /updater/restart",
http.MethodGet + " /v1/version",
http.MethodGet + " /v1/vpn/status",
http.MethodPut + " /v1/vpn/status",
http.MethodGet + " /v1/openvpn/status",
http.MethodPut + " /v1/openvpn/status",
http.MethodGet + " /v1/openvpn/portforwarded",
http.MethodGet + " /v1/dns/status",
http.MethodPut + " /v1/dns/status",
http.MethodGet + " /v1/updater/status",
http.MethodPut + " /v1/updater/status",
http.MethodGet + " /v1/publicip/ip",
},
}})
}
func (s Settings) Validate() (err error) {
for i, role := range s.Roles {
err = role.validate()
if err != nil {
return fmt.Errorf("role %s (%d of %d): %w",
role.Name, i+1, len(s.Roles), err)
}
}
return nil
}
const (
AuthNone = "none"
AuthAPIKey = "apikey"
AuthBasic = "basic"
)
// Role contains the role name, authentication method name and
// routes that the role can access.
type Role struct {
// Name is the role name and is only used for documentation
// and in the authentication middleware debug logs.
Name string
// Auth is the authentication method to use, which can be 'none' or 'apikey'.
Auth string
// APIKey is the API key to use when using the 'apikey' authentication.
APIKey string
// Username for HTTP Basic authentication method.
Username string
// Password for HTTP Basic authentication method.
Password string
// Routes is a list of routes that the role can access in the format
// "HTTP_METHOD PATH", for example "GET /v1/vpn/status"
Routes []string
}
var (
ErrMethodNotSupported = errors.New("authentication method not supported")
ErrAPIKeyEmpty = errors.New("api key is empty")
ErrBasicUsernameEmpty = errors.New("username is empty")
ErrBasicPasswordEmpty = errors.New("password is empty")
ErrRouteNotSupported = errors.New("route not supported by the control server")
)
func (r Role) validate() (err error) {
err = validate.IsOneOf(r.Auth, AuthNone, AuthAPIKey, AuthBasic)
if err != nil {
return fmt.Errorf("%w: %s", ErrMethodNotSupported, r.Auth)
}
switch {
case r.Auth == AuthAPIKey && r.APIKey == "":
return fmt.Errorf("for role %s: %w", r.Name, ErrAPIKeyEmpty)
case r.Auth == AuthBasic && r.Username == "":
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicUsernameEmpty)
case r.Auth == AuthBasic && r.Password == "":
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicPasswordEmpty)
}
for i, route := range r.Routes {
_, ok := validRoutes[route]
if !ok {
return fmt.Errorf("route %d of %d: %w: %s",
i+1, len(r.Routes), ErrRouteNotSupported, route)
}
}
return nil
}
// WARNING: do not mutate programmatically.
var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
http.MethodGet + " /openvpn/actions/restart": {},
http.MethodGet + " /unbound/actions/restart": {},
http.MethodGet + " /updater/restart": {},
http.MethodGet + " /v1/version": {},
http.MethodGet + " /v1/vpn/status": {},
http.MethodPut + " /v1/vpn/status": {},
http.MethodGet + " /v1/vpn/settings": {},
http.MethodPut + " /v1/vpn/settings": {},
http.MethodGet + " /v1/openvpn/status": {},
http.MethodPut + " /v1/openvpn/status": {},
http.MethodGet + " /v1/openvpn/portforwarded": {},
http.MethodGet + " /v1/openvpn/settings": {},
http.MethodGet + " /v1/dns/status": {},
http.MethodPut + " /v1/dns/status": {},
http.MethodGet + " /v1/updater/status": {},
http.MethodPut + " /v1/updater/status": {},
http.MethodGet + " /v1/publicip/ip": {},
}