From 38b9e10ecf56377ccae8b2231dcaffaeec52e142 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9goire=20Duch=C3=AAne?= Date: Sun, 7 Nov 2021 10:14:19 +0000 Subject: [PATCH] Set the Allow header when ErrMethodMismatch is set MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This changes RouteMatch so it contains a slice of the methods that would have been accepted. That slice is then used to populate the Allow header accordingly. This makes the default behavior of mux when returning 405 Method Not Allowed compliant with [RFC 7231ยง6.5.5][RFC7231]. [RFC7231]: https://datatracker.ietf.org/doc/html/rfc7231#section-6.5.5 --- mux.go | 19 ++++++++++++------- mux_test.go | 7 ++++++- route.go | 7 ++++++- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/mux.go b/mux.go index f126a60..5807eaf 100644 --- a/mux.go +++ b/mux.go @@ -11,6 +11,7 @@ import ( "net/http" "path" "regexp" + "strings" ) var ( @@ -202,7 +203,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if handler == nil && match.MatchErr == ErrMethodMismatch { - handler = methodNotAllowedHandler() + handler = methodNotAllowedHandler(match.AllowedMethods) } if handler == nil { @@ -417,6 +418,10 @@ type RouteMatch struct { // It is set to ErrMethodMismatch if there is a mismatch in // the request method and route method MatchErr error + + // AllowedMethods contains the list of methods allowed by a route when + // MatchErr is set to ErrMethodMismatch. + AllowedMethods []string } type contextKey int @@ -598,11 +603,11 @@ func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]s return true } -// methodNotAllowed replies to the request with an HTTP status code 405. -func methodNotAllowed(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusMethodNotAllowed) -} - // methodNotAllowedHandler returns a simple request handler // that replies to each request with a status code 405. -func methodNotAllowedHandler() http.Handler { return http.HandlerFunc(methodNotAllowed) } +func methodNotAllowedHandler(allowed []string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Allow", strings.Join(allowed, ",")) + w.WriteHeader(http.StatusMethodNotAllowed) + }) +} diff --git a/mux_test.go b/mux_test.go index 2d8d2b3..992926d 100644 --- a/mux_test.go +++ b/mux_test.go @@ -2052,6 +2052,9 @@ func TestNoMatchMethodErrorHandler(t *testing.T) { if resp.Code != http.StatusMethodNotAllowed { t.Errorf("Expecting code %v", 405) } + if hdr := resp.Header().Get("Allow"); hdr != "GET,POST" { + t.Errorf(`Expected Allow header to be "GET,POST" (got %q)`, hdr) + } // Add matching route r.HandleFunc("/", func1).Methods("PUT") @@ -2721,7 +2724,6 @@ func TestMethodNotAllowed(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } router := NewRouter() router.HandleFunc("/thing", handler).Methods(http.MethodGet) - router.HandleFunc("/something", handler).Methods(http.MethodGet) w := NewRecorder() req := newRequest(http.MethodPut, "/thing") @@ -2731,6 +2733,9 @@ func TestMethodNotAllowed(t *testing.T) { if w.Code != http.StatusMethodNotAllowed { t.Fatalf("Expected status code 405 (got %d)", w.Code) } + if hdr := w.Header().Get("Allow"); hdr != http.MethodGet { + t.Fatalf(`Expected Allow header to be "GET" (got %q)`, hdr) + } } type customMethodNotAllowedHandler struct { diff --git a/route.go b/route.go index 750afe5..c338727 100644 --- a/route.go +++ b/route.go @@ -44,12 +44,14 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool { } var matchErr error + var allowedMethods []string // Match everything. for _, m := range r.matchers { if matched := m.Match(req, match); !matched { - if _, ok := m.(methodMatcher); ok { + if m, ok := m.(methodMatcher); ok { matchErr = ErrMethodMismatch + allowedMethods = m continue } @@ -71,6 +73,9 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool { if matchErr != nil { match.MatchErr = matchErr + if matchErr == ErrMethodMismatch { + match.AllowedMethods = allowedMethods + } return false }