diff --git a/middleware.go b/middleware.go index cb51c56..03a9d9b 100644 --- a/middleware.go +++ b/middleware.go @@ -20,6 +20,27 @@ func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler { return mw(handler) } +// NewMiddleware helper method which allows creating a middleware from a function which receives the response and request. +// The typical case would be passing a closure +func NewMiddleware(mw func(http.ResponseWriter, *http.Request)) MiddlewareFunc { + return func(handler http.Handler) http.Handler { + return &genericMW{ + process: func(w http.ResponseWriter, r *http.Request) { + mw(w, r) + handler.ServeHTTP(w, r) + }, + } + } +} + +// genericMW is for creating a middleware from a closure, it's used in the public method NewMiddleware +type genericMW struct { + process func(http.ResponseWriter, *http.Request) +} +func (instance *genericMW) ServeHTTP(w http.ResponseWriter, req *http.Request) { + instance.process(w, req) +} + // Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router. func (r *Router) Use(mwf ...MiddlewareFunc) { for _, fn := range mwf { diff --git a/middleware_test.go b/middleware_test.go index e9f0ef5..47730e3 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -3,6 +3,7 @@ package mux import ( "bytes" "net/http" + "net/http/httptest" "testing" ) @@ -563,3 +564,36 @@ func TestMiddlewareOnMultiSubrouter(t *testing.T) { } }) } + +func TestNewMiddleware(t *testing.T) { + + mwFuncCalls := 0 + + handler := &mockHandler{ + serverHttp: func(w http.ResponseWriter, r *http.Request) { + if mwFuncCalls != 1 { + t.Fatalf("Expected Handler to be called after mw run. However, the middleware run {%d} times before the handler.", mwFuncCalls) + } + mwFuncCalls++ + }, + } + + newHandler := NewMiddleware(func(http.ResponseWriter, *http.Request) { + mwFuncCalls++ + }).Middleware(handler) + + newHandler.ServeHTTP(&httptest.ResponseRecorder{}, &http.Request{}) + + if mwFuncCalls != 2 { + t.Fatalf("Expected mwFuncCalls to be 2, but got {%d}", mwFuncCalls) + } + +} + +type mockHandler struct { + serverHttp func(writer http.ResponseWriter, request *http.Request) +} + +func (d *mockHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + d.serverHttp(writer, request) +}