diff options
| author | Grégoire Duchêne <gduchene@awhk.org> | 2022-06-18 12:52:57 +0100 |
|---|---|---|
| committer | Grégoire Duchêne <gduchene@awhk.org> | 2022-06-18 12:52:57 +0100 |
| commit | 41a47c757dca86ec0684b060bdd1d3c4d55cc81f (patch) | |
| tree | 949e038cb7dd2166c41cedc7fc70122855f6656b | |
| parent | d38a4ae585bb5061b264c667d65f2adf922934a7 (diff) | |
Add FilterHTTPMethod and FilteringHTTPHandler
| -rw-r--r-- | http.go | 48 | ||||
| -rw-r--r-- | http_test.go | 97 |
2 files changed, 145 insertions, 0 deletions
@@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: © 2022 Grégoire Duchêne <gduchene@awhk.org> +// SPDX-License-Identifier: ISC + +package core + +import ( + "net/http" + "sort" + "strings" +) + +// FilteringHTTPHandler returns a handler that will check that a request +// was not filtered before handing it over to the passed handler. +func FilteringHTTPHandler(handler http.Handler, filters ...HTTPFilterFunc) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + for _, filter := range filters { + if filter(w, req) { + return + } + } + handler.ServeHTTP(w, req) + }) +} + +// HTTPFilterFunc describes a filtering function for HTTP headers. The +// filtering function must return true if a request should be filtered +// and false otherwise. The filtering function may only call functions +// on the http.ResponseWriter or change the http.Request if a request is +// filtered. +type HTTPFilterFunc func(http.ResponseWriter, *http.Request) bool + +// FilterHTTPMethod is an HTTPFilterFunc that filters requests based on +// the HTTP methods passed. Requests that do not have a matching method +// will be filtered. +func FilterHTTPMethod(methods ...string) HTTPFilterFunc { + sort.Strings(methods) + allowed := strings.Join(methods, ", ") + return func(w http.ResponseWriter, req *http.Request) bool { + for _, method := range methods { + if method == req.Method { + return false + } + } + w.Header().Set("Allowed", allowed) + w.WriteHeader(http.StatusMethodNotAllowed) + return true + } +} diff --git a/http_test.go b/http_test.go new file mode 100644 index 0000000..a8c3cb5 --- /dev/null +++ b/http_test.go @@ -0,0 +1,97 @@ +// SPDX-FileCopyrightText: © 2022 Grégoire Duchêne <gduchene@awhk.org> +// SPDX-License-Identifier: ISC + +package core_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "go.awhk.org/core" +) + +func TestFilteringHTTPHandler(s *testing.T) { + t := core.T{T: s} + + handler := core.FilteringHTTPHandler( + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }), + core.FilterHTTPMethod(http.MethodHead), + ) + for _, tc := range []struct { + name string + method string + + expHeader http.Header + expStatusCode int + }{ + { + name: "Success", + method: http.MethodHead, + + expHeader: http.Header{}, + expStatusCode: http.StatusOK, + }, + { + name: "WhenFiltered", + method: http.MethodGet, + + expHeader: http.Header{"Allowed": {"HEAD"}}, + expStatusCode: http.StatusMethodNotAllowed, + }, + } { + t.Run(tc.name, func(t *core.T) { + var ( + req = httptest.NewRequest(tc.method, "/", nil) + w = httptest.NewRecorder() + ) + handler.ServeHTTP(w, req) + + res := w.Result() + t.AssertEqual(tc.expHeader, res.Header) + t.AssertEqual(tc.expStatusCode, res.StatusCode) + }) + } +} + +func TestFilterHTTPMethod(s *testing.T) { + t := core.T{T: s} + + filter := core.FilterHTTPMethod(http.MethodPost, http.MethodGet) + for _, tc := range []struct { + name string + method string + + expAllowed string + expFiltered bool + expStatusCode int + }{ + { + name: "Success", + method: http.MethodPost, + + expFiltered: false, + expStatusCode: http.StatusOK, + }, + { + name: "WhenFiltered", + method: http.MethodHead, + + expAllowed: "GET, POST", + expFiltered: true, + expStatusCode: http.StatusMethodNotAllowed, + }, + } { + t.Run(tc.name, func(t *core.T) { + var ( + req = httptest.NewRequest(tc.method, "/", nil) + w = httptest.NewRecorder() + ) + t.AssertEqual(tc.expFiltered, filter(w, req)) + + res := w.Result() + t.AssertEqual(tc.expAllowed, res.Header.Get("Allowed")) + t.AssertEqual(tc.expStatusCode, res.StatusCode) + }) + } +} |
