From a659b61323b075cb38ad03aa43749e52eb4948e0 Mon Sep 17 00:00:00 2001 From: Mayank Patel Date: Wed, 30 Aug 2017 11:09:17 +0530 Subject: [PATCH] Fix #271: Return 405 instead of 404 when request method doesn't match the route --- mux.go | 33 +++++++++++++++++++++++++++++++++ mux_test.go | 39 +++++++++++++++++++++++++++++++++++++++ old_test.go | 8 +------- route.go | 16 ++++++++++++++++ 4 files changed, 89 insertions(+), 7 deletions(-) diff --git a/mux.go b/mux.go index aa19597..fb69196 100644 --- a/mux.go +++ b/mux.go @@ -13,6 +13,10 @@ import ( "strings" ) +var ( + ErrMethodMismatch = errors.New("method is not allowed") +) + // NewRouter returns a new router instance. func NewRouter() *Router { return &Router{namedRoutes: make(map[string]*Route), KeepContext: false} @@ -39,6 +43,10 @@ func NewRouter() *Router { type Router struct { // Configurable Handler to be used when no route matches. NotFoundHandler http.Handler + + // Configurable Handler to be used when the request method does not match the route. + MethodNotAllowedHandler http.Handler + // Parent route, if this is a subrouter. parent parentRoute // Routes to be matched, in order. @@ -65,6 +73,11 @@ func (r *Router) Match(req *http.Request, match *RouteMatch) bool { } } + if match.MatchErr == ErrMethodMismatch && r.MethodNotAllowedHandler != nil { + match.Handler = r.MethodNotAllowedHandler + return true + } + // Closest match for a router (includes sub-routers) if r.NotFoundHandler != nil { match.Handler = r.NotFoundHandler @@ -105,9 +118,15 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { req = setVars(req, match.Vars) req = setCurrentRoute(req, match.Route) } + + if handler == nil && match.MatchErr == ErrMethodMismatch { + handler = methodNotAllowedHandler() + } + if handler == nil { handler = http.NotFoundHandler() } + if !r.KeepContext { defer contextClear(req) } @@ -344,6 +363,11 @@ type RouteMatch struct { Route *Route Handler http.Handler Vars map[string]string + + // MatchErr is set to appropriate matching error + // It is set to ErrMethodMismatch if there is a mismatch in + // the request method and route method + MatchErr error } type contextKey int @@ -545,3 +569,12 @@ func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]s } return true } + +// methodNotAllowed replies to the request with an HTTP status code 405. +func methodNotAllowed(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusMethodNotAllowed) +} + +// methodNotAllowedHandler returns a simple request handler +// that replies to each request with a status code 405. +func methodNotAllowedHandler() http.Handler { return http.HandlerFunc(methodNotAllowed) } diff --git a/mux_test.go b/mux_test.go index f6cfb44..413f13f 100644 --- a/mux_test.go +++ b/mux_test.go @@ -1871,3 +1871,42 @@ func newRequest(method, url string) *http.Request { } return req } + +func TestNoMatchMethodErrorHandler(t *testing.T) { + func1 := func(w http.ResponseWriter, r *http.Request) {} + + r := NewRouter() + r.HandleFunc("/", func1).Methods("GET", "POST") + + req, _ := http.NewRequest("PUT", "http://localhost/", nil) + match := new(RouteMatch) + matched := r.Match(req, match) + + if matched { + t.Error("Should not have matched route for methods") + } + + if match.MatchErr != ErrMethodMismatch { + t.Error("Should get ErrMethodMismatch error") + } + + resp := NewRecorder() + r.ServeHTTP(resp, req) + if resp.Code != 405 { + t.Errorf("Expecting code %v", 405) + } + + //Add matching route now + r.HandleFunc("/", func1).Methods("PUT") + + match = new(RouteMatch) + matched = r.Match(req, match) + + if !matched { + t.Error("Should have matched route for methods") + } + + if match.MatchErr != nil { + t.Error("Should not have any matching error. Found:", match.MatchErr) + } +} diff --git a/old_test.go b/old_test.go index 9bdc5e5..3751e47 100644 --- a/old_test.go +++ b/old_test.go @@ -121,12 +121,7 @@ func TestRouteMatchers(t *testing.T) { var routeMatch RouteMatch matched := router.Match(request, &routeMatch) if matched != shouldMatch { - // Need better messages. :) - if matched { - t.Errorf("Should match.") - } else { - t.Errorf("Should not match.") - } + t.Errorf("Expected: %v\nGot: %v\nRequest: %v %v", shouldMatch, matched, request.Method, url) } if matched { @@ -188,7 +183,6 @@ func TestRouteMatchers(t *testing.T) { match(true) // 2nd route -------------------------------------------------------------- - // Everything match. reset2() match(true) diff --git a/route.go b/route.go index 6d4a07a..6863adb 100644 --- a/route.go +++ b/route.go @@ -52,12 +52,27 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool { if r.buildOnly || r.err != nil { return false } + + var matchErr error + // Match everything. for _, m := range r.matchers { if matched := m.Match(req, match); !matched { + if _, ok := m.(methodMatcher); ok { + matchErr = ErrMethodMismatch + continue + } + matchErr = nil return false } } + + if matchErr != nil { + match.MatchErr = matchErr + return false + } + + match.MatchErr = nil // Yay, we have a match. Let's collect some info about it. if match.Route == nil { match.Route = r @@ -68,6 +83,7 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool { if match.Vars == nil { match.Vars = make(map[string]string) } + // Set variables. if r.regexp != nil { r.regexp.setMatch(req, match, r)