Browse Source
* mux.Router now has a `Use` method that allows you to add middleware to request processing.pull/330/head v1.6.1
5 changed files with 519 additions and 0 deletions
@ -0,0 +1,28 @@ |
|||||||
|
package mux |
||||||
|
|
||||||
|
import "net/http" |
||||||
|
|
||||||
|
// MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler.
|
||||||
|
// Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed
|
||||||
|
// to it, and then calls the handler passed as parameter to the MiddlewareFunc.
|
||||||
|
type MiddlewareFunc func(http.Handler) http.Handler |
||||||
|
|
||||||
|
// middleware interface is anything which implements a MiddlewareFunc named Middleware.
|
||||||
|
type middleware interface { |
||||||
|
Middleware(handler http.Handler) http.Handler |
||||||
|
} |
||||||
|
|
||||||
|
// MiddlewareFunc also implements the middleware interface.
|
||||||
|
func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler { |
||||||
|
return mw(handler) |
||||||
|
} |
||||||
|
|
||||||
|
// Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
|
||||||
|
func (r *Router) Use(mwf MiddlewareFunc) { |
||||||
|
r.middlewares = append(r.middlewares, mwf) |
||||||
|
} |
||||||
|
|
||||||
|
// useInterface appends a middleware to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
|
||||||
|
func (r *Router) useInterface(mw middleware) { |
||||||
|
r.middlewares = append(r.middlewares, mw) |
||||||
|
} |
||||||
@ -0,0 +1,336 @@ |
|||||||
|
package mux |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"net/http" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
type testMiddleware struct { |
||||||
|
timesCalled uint |
||||||
|
} |
||||||
|
|
||||||
|
func (tm *testMiddleware) Middleware(h http.Handler) http.Handler { |
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||||
|
tm.timesCalled++ |
||||||
|
h.ServeHTTP(w, r) |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func dummyHandler(w http.ResponseWriter, r *http.Request) {} |
||||||
|
|
||||||
|
func TestMiddlewareAdd(t *testing.T) { |
||||||
|
router := NewRouter() |
||||||
|
router.HandleFunc("/", dummyHandler).Methods("GET") |
||||||
|
|
||||||
|
mw := &testMiddleware{} |
||||||
|
|
||||||
|
router.useInterface(mw) |
||||||
|
if len(router.middlewares) != 1 || router.middlewares[0] != mw { |
||||||
|
t.Fatal("Middleware was not added correctly") |
||||||
|
} |
||||||
|
|
||||||
|
router.Use(mw.Middleware) |
||||||
|
if len(router.middlewares) != 2 { |
||||||
|
t.Fatal("MiddlewareFunc method was not added correctly") |
||||||
|
} |
||||||
|
|
||||||
|
banalMw := func(handler http.Handler) http.Handler { |
||||||
|
return handler |
||||||
|
} |
||||||
|
router.Use(banalMw) |
||||||
|
if len(router.middlewares) != 3 { |
||||||
|
t.Fatal("MiddlewareFunc method was not added correctly") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMiddleware(t *testing.T) { |
||||||
|
router := NewRouter() |
||||||
|
router.HandleFunc("/", dummyHandler).Methods("GET") |
||||||
|
|
||||||
|
mw := &testMiddleware{} |
||||||
|
router.useInterface(mw) |
||||||
|
|
||||||
|
rw := NewRecorder() |
||||||
|
req := newRequest("GET", "/") |
||||||
|
|
||||||
|
// Test regular middleware call
|
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if mw.timesCalled != 1 { |
||||||
|
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) |
||||||
|
} |
||||||
|
|
||||||
|
// Middleware should not be called for 404
|
||||||
|
req = newRequest("GET", "/not/found") |
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if mw.timesCalled != 1 { |
||||||
|
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) |
||||||
|
} |
||||||
|
|
||||||
|
// Middleware should not be called if there is a method mismatch
|
||||||
|
req = newRequest("POST", "/") |
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if mw.timesCalled != 1 { |
||||||
|
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) |
||||||
|
} |
||||||
|
|
||||||
|
// Add the middleware again as function
|
||||||
|
router.Use(mw.Middleware) |
||||||
|
req = newRequest("GET", "/") |
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if mw.timesCalled != 3 { |
||||||
|
t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled) |
||||||
|
} |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
func TestMiddlewareSubrouter(t *testing.T) { |
||||||
|
router := NewRouter() |
||||||
|
router.HandleFunc("/", dummyHandler).Methods("GET") |
||||||
|
|
||||||
|
subrouter := router.PathPrefix("/sub").Subrouter() |
||||||
|
subrouter.HandleFunc("/x", dummyHandler).Methods("GET") |
||||||
|
|
||||||
|
mw := &testMiddleware{} |
||||||
|
subrouter.useInterface(mw) |
||||||
|
|
||||||
|
rw := NewRecorder() |
||||||
|
req := newRequest("GET", "/") |
||||||
|
|
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if mw.timesCalled != 0 { |
||||||
|
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled) |
||||||
|
} |
||||||
|
|
||||||
|
req = newRequest("GET", "/sub/") |
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if mw.timesCalled != 0 { |
||||||
|
t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled) |
||||||
|
} |
||||||
|
|
||||||
|
req = newRequest("GET", "/sub/x") |
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if mw.timesCalled != 1 { |
||||||
|
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) |
||||||
|
} |
||||||
|
|
||||||
|
req = newRequest("GET", "/sub/not/found") |
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if mw.timesCalled != 1 { |
||||||
|
t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) |
||||||
|
} |
||||||
|
|
||||||
|
router.useInterface(mw) |
||||||
|
|
||||||
|
req = newRequest("GET", "/") |
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if mw.timesCalled != 2 { |
||||||
|
t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled) |
||||||
|
} |
||||||
|
|
||||||
|
req = newRequest("GET", "/sub/x") |
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if mw.timesCalled != 4 { |
||||||
|
t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMiddlewareExecution(t *testing.T) { |
||||||
|
mwStr := []byte("Middleware\n") |
||||||
|
handlerStr := []byte("Logic\n") |
||||||
|
|
||||||
|
router := NewRouter() |
||||||
|
router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { |
||||||
|
w.Write(handlerStr) |
||||||
|
}) |
||||||
|
|
||||||
|
rw := NewRecorder() |
||||||
|
req := newRequest("GET", "/") |
||||||
|
|
||||||
|
// Test handler-only call
|
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
|
||||||
|
if bytes.Compare(rw.Body.Bytes(), handlerStr) != 0 { |
||||||
|
t.Fatal("Handler response is not what it should be") |
||||||
|
} |
||||||
|
|
||||||
|
// Test middleware call
|
||||||
|
rw = NewRecorder() |
||||||
|
|
||||||
|
router.Use(func(h http.Handler) http.Handler { |
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||||
|
w.Write(mwStr) |
||||||
|
h.ServeHTTP(w, r) |
||||||
|
}) |
||||||
|
}) |
||||||
|
|
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if bytes.Compare(rw.Body.Bytes(), append(mwStr, handlerStr...)) != 0 { |
||||||
|
t.Fatal("Middleware + handler response is not what it should be") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMiddlewareNotFound(t *testing.T) { |
||||||
|
mwStr := []byte("Middleware\n") |
||||||
|
handlerStr := []byte("Logic\n") |
||||||
|
|
||||||
|
router := NewRouter() |
||||||
|
router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { |
||||||
|
w.Write(handlerStr) |
||||||
|
}) |
||||||
|
router.Use(func(h http.Handler) http.Handler { |
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||||
|
w.Write(mwStr) |
||||||
|
h.ServeHTTP(w, r) |
||||||
|
}) |
||||||
|
}) |
||||||
|
|
||||||
|
// Test not found call with default handler
|
||||||
|
rw := NewRecorder() |
||||||
|
req := newRequest("GET", "/notfound") |
||||||
|
|
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if bytes.Contains(rw.Body.Bytes(), mwStr) { |
||||||
|
t.Fatal("Middleware was called for a 404") |
||||||
|
} |
||||||
|
|
||||||
|
// Test not found call with custom handler
|
||||||
|
rw = NewRecorder() |
||||||
|
req = newRequest("GET", "/notfound") |
||||||
|
|
||||||
|
router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { |
||||||
|
rw.Write([]byte("Custom 404 handler")) |
||||||
|
}) |
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
|
||||||
|
if bytes.Contains(rw.Body.Bytes(), mwStr) { |
||||||
|
t.Fatal("Middleware was called for a custom 404") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMiddlewareMethodMismatch(t *testing.T) { |
||||||
|
mwStr := []byte("Middleware\n") |
||||||
|
handlerStr := []byte("Logic\n") |
||||||
|
|
||||||
|
router := NewRouter() |
||||||
|
router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { |
||||||
|
w.Write(handlerStr) |
||||||
|
}).Methods("GET") |
||||||
|
|
||||||
|
router.Use(func(h http.Handler) http.Handler { |
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||||
|
w.Write(mwStr) |
||||||
|
h.ServeHTTP(w, r) |
||||||
|
}) |
||||||
|
}) |
||||||
|
|
||||||
|
// Test method mismatch
|
||||||
|
rw := NewRecorder() |
||||||
|
req := newRequest("POST", "/") |
||||||
|
|
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if bytes.Contains(rw.Body.Bytes(), mwStr) { |
||||||
|
t.Fatal("Middleware was called for a method mismatch") |
||||||
|
} |
||||||
|
|
||||||
|
// Test not found call
|
||||||
|
rw = NewRecorder() |
||||||
|
req = newRequest("POST", "/") |
||||||
|
|
||||||
|
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { |
||||||
|
rw.Write([]byte("Method not allowed")) |
||||||
|
}) |
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
|
||||||
|
if bytes.Contains(rw.Body.Bytes(), mwStr) { |
||||||
|
t.Fatal("Middleware was called for a method mismatch") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMiddlewareNotFoundSubrouter(t *testing.T) { |
||||||
|
mwStr := []byte("Middleware\n") |
||||||
|
handlerStr := []byte("Logic\n") |
||||||
|
|
||||||
|
router := NewRouter() |
||||||
|
router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { |
||||||
|
w.Write(handlerStr) |
||||||
|
}) |
||||||
|
|
||||||
|
subrouter := router.PathPrefix("/sub/").Subrouter() |
||||||
|
subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { |
||||||
|
w.Write(handlerStr) |
||||||
|
}) |
||||||
|
|
||||||
|
router.Use(func(h http.Handler) http.Handler { |
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||||
|
w.Write(mwStr) |
||||||
|
h.ServeHTTP(w, r) |
||||||
|
}) |
||||||
|
}) |
||||||
|
|
||||||
|
// Test not found call for default handler
|
||||||
|
rw := NewRecorder() |
||||||
|
req := newRequest("GET", "/sub/notfound") |
||||||
|
|
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if bytes.Contains(rw.Body.Bytes(), mwStr) { |
||||||
|
t.Fatal("Middleware was called for a 404") |
||||||
|
} |
||||||
|
|
||||||
|
// Test not found call with custom handler
|
||||||
|
rw = NewRecorder() |
||||||
|
req = newRequest("GET", "/sub/notfound") |
||||||
|
|
||||||
|
subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { |
||||||
|
rw.Write([]byte("Custom 404 handler")) |
||||||
|
}) |
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
|
||||||
|
if bytes.Contains(rw.Body.Bytes(), mwStr) { |
||||||
|
t.Fatal("Middleware was called for a custom 404") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMiddlewareMethodMismatchSubrouter(t *testing.T) { |
||||||
|
mwStr := []byte("Middleware\n") |
||||||
|
handlerStr := []byte("Logic\n") |
||||||
|
|
||||||
|
router := NewRouter() |
||||||
|
router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { |
||||||
|
w.Write(handlerStr) |
||||||
|
}) |
||||||
|
|
||||||
|
subrouter := router.PathPrefix("/sub/").Subrouter() |
||||||
|
subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { |
||||||
|
w.Write(handlerStr) |
||||||
|
}).Methods("GET") |
||||||
|
|
||||||
|
router.Use(func(h http.Handler) http.Handler { |
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||||
|
w.Write(mwStr) |
||||||
|
h.ServeHTTP(w, r) |
||||||
|
}) |
||||||
|
}) |
||||||
|
|
||||||
|
// Test method mismatch without custom handler
|
||||||
|
rw := NewRecorder() |
||||||
|
req := newRequest("POST", "/sub/") |
||||||
|
|
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
if bytes.Contains(rw.Body.Bytes(), mwStr) { |
||||||
|
t.Fatal("Middleware was called for a method mismatch") |
||||||
|
} |
||||||
|
|
||||||
|
// Test method mismatch with custom handler
|
||||||
|
rw = NewRecorder() |
||||||
|
req = newRequest("POST", "/sub/") |
||||||
|
|
||||||
|
router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { |
||||||
|
rw.Write([]byte("Method not allowed")) |
||||||
|
}) |
||||||
|
router.ServeHTTP(rw, req) |
||||||
|
|
||||||
|
if bytes.Contains(rw.Body.Bytes(), mwStr) { |
||||||
|
t.Fatal("Middleware was called for a method mismatch") |
||||||
|
} |
||||||
|
} |
||||||
Loading…
Reference in new issue