Browse Source

Add CORSMethodMiddleware (#366)

CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
on a request, by matching routes based only on paths. It also handles
OPTIONS requests, by settings Access-Control-Allow-Methods, and then
returning without calling the next HTTP handler.
pull/374/head
Franklin Harding 8 years ago committed by Matt Silverlock
parent
commit
5e55a4adb8
  1. 44
      middleware.go
  2. 41
      middleware_test.go
  3. 8
      mux_test.go

44
middleware.go

@ -1,6 +1,9 @@
package mux package mux
import "net/http" import (
"net/http"
"strings"
)
// MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler. // MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler.
// Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed // Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed
@ -28,3 +31,42 @@ func (r *Router) Use(mwf ...MiddlewareFunc) {
func (r *Router) useInterface(mw middleware) { func (r *Router) useInterface(mw middleware) {
r.middlewares = append(r.middlewares, mw) r.middlewares = append(r.middlewares, mw)
} }
// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
// on a request, by matching routes based only on paths. It also handles
// OPTIONS requests, by settings Access-Control-Allow-Methods, and then
// returning without calling the next http handler.
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var allMethods []string
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
for _, m := range route.matchers {
if _, ok := m.(*routeRegexp); ok {
if m.Match(req, &RouteMatch{}) {
methods, err := route.GetMethods()
if err != nil {
return err
}
allMethods = append(allMethods, methods...)
}
break
}
}
return nil
})
if err == nil {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ","))
if req.Method == "OPTIONS" {
return
}
}
next.ServeHTTP(w, req)
})
}
}

41
middleware_test.go

@ -3,6 +3,7 @@ package mux
import ( import (
"bytes" "bytes"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
) )
@ -334,3 +335,43 @@ func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
t.Fatal("Middleware was called for a method mismatch") t.Fatal("Middleware was called for a method mismatch")
} }
} }
func TestCORSMethodMiddleware(t *testing.T) {
router := NewRouter()
cases := []struct {
path string
response string
method string
testURL string
expectedAllowedMethods string
}{
{"/g/{o}", "a", "POST", "/g/asdf", "POST,PUT,GET,OPTIONS"},
{"/g/{o}", "b", "PUT", "/g/bla", "POST,PUT,GET,OPTIONS"},
{"/g/{o}", "c", "GET", "/g/orilla", "POST,PUT,GET,OPTIONS"},
{"/g", "d", "POST", "/g", "POST,OPTIONS"},
}
for _, tt := range cases {
router.HandleFunc(tt.path, stringHandler(tt.response)).Methods(tt.method)
}
router.Use(CORSMethodMiddleware(router))
for _, tt := range cases {
rr := httptest.NewRecorder()
req := newRequest(tt.method, tt.testURL)
router.ServeHTTP(rr, req)
if rr.Body.String() != tt.response {
t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String())
}
allowedMethods := rr.HeaderMap.Get("Access-Control-Allow-Methods")
if allowedMethods != tt.expectedAllowedMethods {
t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods)
}
}
}

8
mux_test.go

@ -2315,6 +2315,14 @@ func stringMapEqual(m1, m2 map[string]string) bool {
return true return true
} }
// stringHandler returns a handler func that writes a message 's' to the
// http.ResponseWriter.
func stringHandler(s string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(s))
}
}
// newRequest is a helper function to create a new request with a method and url. // newRequest is a helper function to create a new request with a method and url.
// The request returned is a 'server' request as opposed to a 'client' one through // The request returned is a 'server' request as opposed to a 'client' one through
// simulated write onto the wire and read off of the wire. // simulated write onto the wire and read off of the wire.

Loading…
Cancel
Save