diff --git a/internal/server/handler.go b/internal/server/handler.go index 4b01aa3a..443088c0 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -18,13 +18,14 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool, ) http.Handler { handler := &handler{} + vpn := newVPNHandler(ctx, vpnLooper, logger) openvpn := newOpenvpnHandler(ctx, vpnLooper, pfGetter, logger) dns := newDNSHandler(ctx, unboundLooper, logger) updater := newUpdaterHandler(ctx, updaterLooper, logger) publicip := newPublicIPHandler(publicIPLooper, logger) handler.v0 = newHandlerV0(ctx, logger, vpnLooper, unboundLooper, updaterLooper) - handler.v1 = newHandlerV1(logger, buildInfo, openvpn, dns, updater, publicip) + handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip) handlerWithLog := withLogMiddleware(handler, logger, logging) handler.setLogEnabled = handlerWithLog.setEnabled diff --git a/internal/server/handlerv1.go b/internal/server/handlerv1.go index f2296eb3..99fc97ca 100644 --- a/internal/server/handlerv1.go +++ b/internal/server/handlerv1.go @@ -10,10 +10,11 @@ import ( ) func newHandlerV1(w warner, buildInfo models.BuildInformation, - openvpn, dns, updater, publicip http.Handler) http.Handler { + vpn, openvpn, dns, updater, publicip http.Handler) http.Handler { return &handlerV1{ warner: w, buildInfo: buildInfo, + vpn: vpn, openvpn: openvpn, dns: dns, updater: updater, @@ -24,6 +25,7 @@ func newHandlerV1(w warner, buildInfo models.BuildInformation, type handlerV1 struct { warner warner buildInfo models.BuildInformation + vpn http.Handler openvpn http.Handler dns http.Handler updater http.Handler @@ -34,6 +36,8 @@ func (h *handlerV1) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch { case r.RequestURI == "/version" && r.Method == http.MethodGet: h.getVersion(w) + case strings.HasPrefix(r.RequestURI, "/vpn"): + h.vpn.ServeHTTP(w, r) case strings.HasPrefix(r.RequestURI, "/openvpn"): h.openvpn.ServeHTTP(w, r) case strings.HasPrefix(r.RequestURI, "/dns"): diff --git a/internal/server/helpers.go b/internal/server/helpers.go deleted file mode 100644 index b32f2f5a..00000000 --- a/internal/server/helpers.go +++ /dev/null @@ -1,3 +0,0 @@ -package server - -func stringPtr(s string) *string { return &s } diff --git a/internal/server/vpn.go b/internal/server/vpn.go new file mode 100644 index 00000000..9c452c45 --- /dev/null +++ b/internal/server/vpn.go @@ -0,0 +1,93 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "strings" +) + +func newVPNHandler(ctx context.Context, looper VPNLooper, + w warner) http.Handler { + return &vpnHandler{ + ctx: ctx, + looper: looper, + warner: w, + } +} + +type vpnHandler struct { + ctx context.Context //nolint:containedctx + looper VPNLooper + warner warner +} + +func (h *vpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + r.RequestURI = strings.TrimPrefix(r.RequestURI, "/vpn") + switch r.RequestURI { + case "/status": + switch r.Method { + case http.MethodGet: + h.getStatus(w) + case http.MethodPut: + h.setStatus(w, r) + default: + http.Error(w, "method "+r.Method+" not supported", http.StatusBadRequest) + } + case "/settings": + switch r.Method { + case http.MethodGet: + h.getSettings(w) + default: + http.Error(w, "method "+r.Method+" not supported", http.StatusBadRequest) + } + default: + http.Error(w, "route "+r.RequestURI+" not supported", http.StatusBadRequest) + } +} + +func (h *vpnHandler) getStatus(w http.ResponseWriter) { + status := h.looper.GetStatus() + encoder := json.NewEncoder(w) + data := statusWrapper{Status: string(status)} + if err := encoder.Encode(data); err != nil { + h.warner.Warn(err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func (h *vpnHandler) setStatus(w http.ResponseWriter, r *http.Request) { + decoder := json.NewDecoder(r.Body) + var data statusWrapper + if err := decoder.Decode(&data); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + status, err := data.getStatus() + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + outcome, err := h.looper.ApplyStatus(h.ctx, status) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + encoder := json.NewEncoder(w) + if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil { + h.warner.Warn(err.Error()) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } +} + +func (h *vpnHandler) getSettings(w http.ResponseWriter) { + settings := h.looper.GetSettings() + encoder := json.NewEncoder(w) + if err := encoder.Encode(settings); err != nil { + h.warner.Warn(err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } +}