aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGrégoire Duchêne <gduchene@awhk.org>2022-06-18 12:52:57 +0100
committerGrégoire Duchêne <gduchene@awhk.org>2022-06-18 12:52:57 +0100
commit41a47c757dca86ec0684b060bdd1d3c4d55cc81f (patch)
tree949e038cb7dd2166c41cedc7fc70122855f6656b
parentd38a4ae585bb5061b264c667d65f2adf922934a7 (diff)
Add FilterHTTPMethod and FilteringHTTPHandler
-rw-r--r--http.go48
-rw-r--r--http_test.go97
2 files changed, 145 insertions, 0 deletions
diff --git a/http.go b/http.go
new file mode 100644
index 0000000..83a6830
--- /dev/null
+++ b/http.go
@@ -0,0 +1,48 @@
+// SPDX-FileCopyrightText: © 2022 Grégoire Duchêne <gduchene@awhk.org>
+// SPDX-License-Identifier: ISC
+
+package core
+
+import (
+ "net/http"
+ "sort"
+ "strings"
+)
+
+// FilteringHTTPHandler returns a handler that will check that a request
+// was not filtered before handing it over to the passed handler.
+func FilteringHTTPHandler(handler http.Handler, filters ...HTTPFilterFunc) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ for _, filter := range filters {
+ if filter(w, req) {
+ return
+ }
+ }
+ handler.ServeHTTP(w, req)
+ })
+}
+
+// HTTPFilterFunc describes a filtering function for HTTP headers. The
+// filtering function must return true if a request should be filtered
+// and false otherwise. The filtering function may only call functions
+// on the http.ResponseWriter or change the http.Request if a request is
+// filtered.
+type HTTPFilterFunc func(http.ResponseWriter, *http.Request) bool
+
+// FilterHTTPMethod is an HTTPFilterFunc that filters requests based on
+// the HTTP methods passed. Requests that do not have a matching method
+// will be filtered.
+func FilterHTTPMethod(methods ...string) HTTPFilterFunc {
+ sort.Strings(methods)
+ allowed := strings.Join(methods, ", ")
+ return func(w http.ResponseWriter, req *http.Request) bool {
+ for _, method := range methods {
+ if method == req.Method {
+ return false
+ }
+ }
+ w.Header().Set("Allowed", allowed)
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ return true
+ }
+}
diff --git a/http_test.go b/http_test.go
new file mode 100644
index 0000000..a8c3cb5
--- /dev/null
+++ b/http_test.go
@@ -0,0 +1,97 @@
+// SPDX-FileCopyrightText: © 2022 Grégoire Duchêne <gduchene@awhk.org>
+// SPDX-License-Identifier: ISC
+
+package core_test
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "go.awhk.org/core"
+)
+
+func TestFilteringHTTPHandler(s *testing.T) {
+ t := core.T{T: s}
+
+ handler := core.FilteringHTTPHandler(
+ http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }),
+ core.FilterHTTPMethod(http.MethodHead),
+ )
+ for _, tc := range []struct {
+ name string
+ method string
+
+ expHeader http.Header
+ expStatusCode int
+ }{
+ {
+ name: "Success",
+ method: http.MethodHead,
+
+ expHeader: http.Header{},
+ expStatusCode: http.StatusOK,
+ },
+ {
+ name: "WhenFiltered",
+ method: http.MethodGet,
+
+ expHeader: http.Header{"Allowed": {"HEAD"}},
+ expStatusCode: http.StatusMethodNotAllowed,
+ },
+ } {
+ t.Run(tc.name, func(t *core.T) {
+ var (
+ req = httptest.NewRequest(tc.method, "/", nil)
+ w = httptest.NewRecorder()
+ )
+ handler.ServeHTTP(w, req)
+
+ res := w.Result()
+ t.AssertEqual(tc.expHeader, res.Header)
+ t.AssertEqual(tc.expStatusCode, res.StatusCode)
+ })
+ }
+}
+
+func TestFilterHTTPMethod(s *testing.T) {
+ t := core.T{T: s}
+
+ filter := core.FilterHTTPMethod(http.MethodPost, http.MethodGet)
+ for _, tc := range []struct {
+ name string
+ method string
+
+ expAllowed string
+ expFiltered bool
+ expStatusCode int
+ }{
+ {
+ name: "Success",
+ method: http.MethodPost,
+
+ expFiltered: false,
+ expStatusCode: http.StatusOK,
+ },
+ {
+ name: "WhenFiltered",
+ method: http.MethodHead,
+
+ expAllowed: "GET, POST",
+ expFiltered: true,
+ expStatusCode: http.StatusMethodNotAllowed,
+ },
+ } {
+ t.Run(tc.name, func(t *core.T) {
+ var (
+ req = httptest.NewRequest(tc.method, "/", nil)
+ w = httptest.NewRecorder()
+ )
+ t.AssertEqual(tc.expFiltered, filter(w, req))
+
+ res := w.Result()
+ t.AssertEqual(tc.expAllowed, res.Header.Get("Allowed"))
+ t.AssertEqual(tc.expStatusCode, res.StatusCode)
+ })
+ }
+}