Browse Source

Fix #271: Return 405 instead of 404 when request method doesn't match the route

pull/289/head
Mayank Patel 8 years ago committed by Kamil Kisiel
parent
commit
a659b61323
  1. 33
      mux.go
  2. 39
      mux_test.go
  3. 8
      old_test.go
  4. 16
      route.go

33
mux.go

@ -13,6 +13,10 @@ import ( @@ -13,6 +13,10 @@ import (
"strings"
)
var (
ErrMethodMismatch = errors.New("method is not allowed")
)
// NewRouter returns a new router instance.
func NewRouter() *Router {
return &Router{namedRoutes: make(map[string]*Route), KeepContext: false}
@ -39,6 +43,10 @@ func NewRouter() *Router { @@ -39,6 +43,10 @@ func NewRouter() *Router {
type Router struct {
// Configurable Handler to be used when no route matches.
NotFoundHandler http.Handler
// Configurable Handler to be used when the request method does not match the route.
MethodNotAllowedHandler http.Handler
// Parent route, if this is a subrouter.
parent parentRoute
// Routes to be matched, in order.
@ -65,6 +73,11 @@ func (r *Router) Match(req *http.Request, match *RouteMatch) bool { @@ -65,6 +73,11 @@ func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
}
}
if match.MatchErr == ErrMethodMismatch && r.MethodNotAllowedHandler != nil {
match.Handler = r.MethodNotAllowedHandler
return true
}
// Closest match for a router (includes sub-routers)
if r.NotFoundHandler != nil {
match.Handler = r.NotFoundHandler
@ -105,9 +118,15 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -105,9 +118,15 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
req = setVars(req, match.Vars)
req = setCurrentRoute(req, match.Route)
}
if handler == nil && match.MatchErr == ErrMethodMismatch {
handler = methodNotAllowedHandler()
}
if handler == nil {
handler = http.NotFoundHandler()
}
if !r.KeepContext {
defer contextClear(req)
}
@ -344,6 +363,11 @@ type RouteMatch struct { @@ -344,6 +363,11 @@ type RouteMatch struct {
Route *Route
Handler http.Handler
Vars map[string]string
// MatchErr is set to appropriate matching error
// It is set to ErrMethodMismatch if there is a mismatch in
// the request method and route method
MatchErr error
}
type contextKey int
@ -545,3 +569,12 @@ func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]s @@ -545,3 +569,12 @@ func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]s
}
return true
}
// methodNotAllowed replies to the request with an HTTP status code 405.
func methodNotAllowed(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusMethodNotAllowed)
}
// methodNotAllowedHandler returns a simple request handler
// that replies to each request with a status code 405.
func methodNotAllowedHandler() http.Handler { return http.HandlerFunc(methodNotAllowed) }

39
mux_test.go

@ -1871,3 +1871,42 @@ func newRequest(method, url string) *http.Request { @@ -1871,3 +1871,42 @@ func newRequest(method, url string) *http.Request {
}
return req
}
func TestNoMatchMethodErrorHandler(t *testing.T) {
func1 := func(w http.ResponseWriter, r *http.Request) {}
r := NewRouter()
r.HandleFunc("/", func1).Methods("GET", "POST")
req, _ := http.NewRequest("PUT", "http://localhost/", nil)
match := new(RouteMatch)
matched := r.Match(req, match)
if matched {
t.Error("Should not have matched route for methods")
}
if match.MatchErr != ErrMethodMismatch {
t.Error("Should get ErrMethodMismatch error")
}
resp := NewRecorder()
r.ServeHTTP(resp, req)
if resp.Code != 405 {
t.Errorf("Expecting code %v", 405)
}
//Add matching route now
r.HandleFunc("/", func1).Methods("PUT")
match = new(RouteMatch)
matched = r.Match(req, match)
if !matched {
t.Error("Should have matched route for methods")
}
if match.MatchErr != nil {
t.Error("Should not have any matching error. Found:", match.MatchErr)
}
}

8
old_test.go

@ -121,12 +121,7 @@ func TestRouteMatchers(t *testing.T) { @@ -121,12 +121,7 @@ func TestRouteMatchers(t *testing.T) {
var routeMatch RouteMatch
matched := router.Match(request, &routeMatch)
if matched != shouldMatch {
// Need better messages. :)
if matched {
t.Errorf("Should match.")
} else {
t.Errorf("Should not match.")
}
t.Errorf("Expected: %v\nGot: %v\nRequest: %v %v", shouldMatch, matched, request.Method, url)
}
if matched {
@ -188,7 +183,6 @@ func TestRouteMatchers(t *testing.T) { @@ -188,7 +183,6 @@ func TestRouteMatchers(t *testing.T) {
match(true)
// 2nd route --------------------------------------------------------------
// Everything match.
reset2()
match(true)

16
route.go

@ -52,12 +52,27 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool { @@ -52,12 +52,27 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
if r.buildOnly || r.err != nil {
return false
}
var matchErr error
// Match everything.
for _, m := range r.matchers {
if matched := m.Match(req, match); !matched {
if _, ok := m.(methodMatcher); ok {
matchErr = ErrMethodMismatch
continue
}
matchErr = nil
return false
}
}
if matchErr != nil {
match.MatchErr = matchErr
return false
}
match.MatchErr = nil
// Yay, we have a match. Let's collect some info about it.
if match.Route == nil {
match.Route = r
@ -68,6 +83,7 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool { @@ -68,6 +83,7 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
if match.Vars == nil {
match.Vars = make(map[string]string)
}
// Set variables.
if r.regexp != nil {
r.regexp.setMatch(req, match, r)

Loading…
Cancel
Save