feat(server): patch VPN settings

- `PUT` at `/v1/vpn/settings`
- Undocumented, experimental for now
This commit is contained in:
Quentin McGaw
2022-08-21 23:36:48 +00:00
parent d685d78e74
commit 0bb320065e
8 changed files with 60 additions and 18 deletions

View File

@@ -431,7 +431,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
"http server", goroutine.OptionTimeout(defaultShutdownTimeout)) "http server", goroutine.OptionTimeout(defaultShutdownTimeout))
httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging, httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
logger.New(log.SetComponent("http server")), logger.New(log.SetComponent("http server")),
buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper) buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper, storage)
if err != nil { if err != nil {
return fmt.Errorf("cannot setup control server: %w", err) return fmt.Errorf("cannot setup control server: %w", err)
} }

View File

@@ -46,7 +46,7 @@ func (s *Settings) Validate(storage Storage) (err error) {
"version": s.Version.validate, "version": s.Version.validate,
// Pprof validation done in pprof constructor // Pprof validation done in pprof constructor
"VPN": func() error { "VPN": func() error {
return s.VPN.validate(storage) return s.VPN.Validate(storage)
}, },
} }
@@ -73,7 +73,7 @@ func (s *Settings) copy() (copied Settings) {
System: s.System.copy(), System: s.System.copy(),
Updater: s.Updater.copy(), Updater: s.Updater.copy(),
Version: s.Version.copy(), Version: s.Version.copy(),
VPN: s.VPN.copy(), VPN: s.VPN.Copy(),
Pprof: s.Pprof.Copy(), Pprof: s.Pprof.Copy(),
} }
} }
@@ -108,7 +108,7 @@ func (s *Settings) OverrideWith(other Settings,
patchedSettings.System.overrideWith(other.System) patchedSettings.System.overrideWith(other.System)
patchedSettings.Updater.overrideWith(other.Updater) patchedSettings.Updater.overrideWith(other.Updater)
patchedSettings.Version.overrideWith(other.Version) patchedSettings.Version.overrideWith(other.Version)
patchedSettings.VPN.overrideWith(other.VPN) patchedSettings.VPN.OverrideWith(other.VPN)
patchedSettings.Pprof.OverrideWith(other.Pprof) patchedSettings.Pprof.OverrideWith(other.Pprof)
err = patchedSettings.Validate(storage) err = patchedSettings.Validate(storage)
if err != nil { if err != nil {

View File

@@ -20,7 +20,7 @@ type VPN struct {
} }
// TODO v4 remove pointer for receiver (because of Surfshark). // TODO v4 remove pointer for receiver (because of Surfshark).
func (v *VPN) validate(storage Storage) (err error) { func (v *VPN) Validate(storage Storage) (err error) {
// Validate Type // Validate Type
validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard} validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard}
if !helpers.IsOneOf(v.Type, validVPNTypes...) { if !helpers.IsOneOf(v.Type, validVPNTypes...) {
@@ -48,7 +48,7 @@ func (v *VPN) validate(storage Storage) (err error) {
return nil return nil
} }
func (v *VPN) copy() (copied VPN) { func (v *VPN) Copy() (copied VPN) {
return VPN{ return VPN{
Type: v.Type, Type: v.Type,
Provider: v.Provider.copy(), Provider: v.Provider.copy(),
@@ -64,7 +64,7 @@ func (v *VPN) mergeWith(other VPN) {
v.Wireguard.mergeWith(other.Wireguard) v.Wireguard.mergeWith(other.Wireguard)
} }
func (v *VPN) overrideWith(other VPN) { func (v *VPN) OverrideWith(other VPN) {
v.Type = helpers.OverrideWithString(v.Type, other.Type) v.Type = helpers.OverrideWithString(v.Type, other.Type)
v.Provider.overrideWith(other.Provider) v.Provider.overrideWith(other.Provider)
v.OpenVPN.overrideWith(other.OpenVPN) v.OpenVPN.overrideWith(other.OpenVPN)

View File

@@ -15,10 +15,11 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
unboundLooper DNSLoop, unboundLooper DNSLoop,
updaterLooper UpdaterLooper, updaterLooper UpdaterLooper,
publicIPLooper PublicIPLoop, publicIPLooper PublicIPLoop,
storage Storage,
) http.Handler { ) http.Handler {
handler := &handler{} handler := &handler{}
vpn := newVPNHandler(ctx, vpnLooper, logger) vpn := newVPNHandler(ctx, vpnLooper, storage, logger)
openvpn := newOpenvpnHandler(ctx, vpnLooper, pfGetter, logger) openvpn := newOpenvpnHandler(ctx, vpnLooper, pfGetter, logger)
dns := newDNSHandler(ctx, unboundLooper, logger) dns := newDNSHandler(ctx, unboundLooper, logger)
updater := newUpdaterHandler(ctx, updaterLooper, logger) updater := newUpdaterHandler(ctx, updaterLooper, logger)

View File

@@ -12,6 +12,7 @@ type VPNLooper interface {
ApplyStatus(ctx context.Context, status models.LoopStatus) ( ApplyStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error) outcome string, err error)
GetSettings() (settings settings.VPN) GetSettings() (settings settings.VPN)
SetSettings(ctx context.Context, settings settings.VPN) (outcome string)
} }
type DNSLoop interface { type DNSLoop interface {
@@ -27,3 +28,7 @@ type PortForwardedGetter interface {
type PublicIPLoop interface { type PublicIPLoop interface {
GetData() (data models.PublicIP) GetData() (data models.PublicIP)
} }
type Storage interface {
GetFilterChoices(provider string) models.FilterChoices
}

View File

@@ -11,9 +11,10 @@ import (
func New(ctx context.Context, address string, logEnabled bool, logger Logger, func New(ctx context.Context, address string, logEnabled bool, logger Logger,
buildInfo models.BuildInformation, openvpnLooper VPNLooper, buildInfo models.BuildInformation, openvpnLooper VPNLooper,
pfGetter PortForwardedGetter, unboundLooper DNSLoop, pfGetter PortForwardedGetter, unboundLooper DNSLoop,
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop) (server *httpserver.Server, err error) { updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage) (
server *httpserver.Server, err error) {
handler := newHandler(ctx, logger, logEnabled, buildInfo, handler := newHandler(ctx, logger, logEnabled, buildInfo,
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper) openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper, storage)
httpServerSettings := httpserver.Settings{ httpServerSettings := httpserver.Settings{
Address: address, Address: address,

View File

@@ -5,21 +5,25 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings" "strings"
"github.com/qdm12/gluetun/internal/configuration/settings"
) )
func newVPNHandler(ctx context.Context, looper VPNLooper, func newVPNHandler(ctx context.Context, looper VPNLooper,
w warner) http.Handler { storage Storage, w warner) http.Handler {
return &vpnHandler{ return &vpnHandler{
ctx: ctx, ctx: ctx,
looper: looper, looper: looper,
warner: w, storage: storage,
warner: w,
} }
} }
type vpnHandler struct { type vpnHandler struct {
ctx context.Context //nolint:containedctx ctx context.Context //nolint:containedctx
looper VPNLooper looper VPNLooper
warner warner storage Storage
warner warner
} }
func (h *vpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *vpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -38,6 +42,8 @@ func (h *vpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
h.getSettings(w) h.getSettings(w)
case http.MethodPut:
h.patchSettings(w, r)
default: default:
http.Error(w, "method "+r.Method+" not supported", http.StatusBadRequest) http.Error(w, "method "+r.Method+" not supported", http.StatusBadRequest)
} }
@@ -91,3 +97,32 @@ func (h *vpnHandler) getSettings(w http.ResponseWriter) {
return return
} }
} }
func (h *vpnHandler) patchSettings(w http.ResponseWriter, r *http.Request) {
var overrideSettings settings.VPN
decoder := json.NewDecoder(r.Body)
err := decoder.Decode(&overrideSettings)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
err = r.Body.Close()
if err != nil {
h.warner.Warn("closing body: " + err.Error())
}
updatedSettings := h.looper.GetSettings() // already copied
updatedSettings.OverrideWith(overrideSettings)
err = updatedSettings.Validate(h.storage)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
outcome := h.looper.SetSettings(h.ctx, updatedSettings)
_, err = w.Write([]byte(outcome))
if err != nil {
h.warner.Warn("writing response: " + err.Error())
}
}

View File

@@ -10,7 +10,7 @@ import (
func (s *State) GetSettings() (vpn settings.VPN) { func (s *State) GetSettings() (vpn settings.VPN) {
s.settingsMu.RLock() s.settingsMu.RLock()
vpn = s.vpn vpn = s.vpn.Copy()
s.settingsMu.RUnlock() s.settingsMu.RUnlock()
return vpn return vpn
} }