Browse Source

Create new helper method for creating middlewares from a closure/anonymous function

pull/647/head
Mariano Sosto 4 years ago
parent
commit
134837b3ce
  1. 21
      middleware.go
  2. 34
      middleware_test.go

21
middleware.go

@ -20,6 +20,27 @@ func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler { @@ -20,6 +20,27 @@ func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler {
return mw(handler)
}
// NewMiddleware helper method which allows creating a middleware from a function which receives the response and request.
// The typical case would be passing a closure
func NewMiddleware(mw func(http.ResponseWriter, *http.Request)) MiddlewareFunc {
return func(handler http.Handler) http.Handler {
return &genericMW{
process: func(w http.ResponseWriter, r *http.Request) {
mw(w, r)
handler.ServeHTTP(w, r)
},
}
}
}
// genericMW is for creating a middleware from a closure, it's used in the public method NewMiddleware
type genericMW struct {
process func(http.ResponseWriter, *http.Request)
}
func (instance *genericMW) ServeHTTP(w http.ResponseWriter, req *http.Request) {
instance.process(w, req)
}
// Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
func (r *Router) Use(mwf ...MiddlewareFunc) {
for _, fn := range mwf {

34
middleware_test.go

@ -3,6 +3,7 @@ package mux @@ -3,6 +3,7 @@ package mux
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
)
@ -563,3 +564,36 @@ func TestMiddlewareOnMultiSubrouter(t *testing.T) { @@ -563,3 +564,36 @@ func TestMiddlewareOnMultiSubrouter(t *testing.T) {
}
})
}
func TestNewMiddleware(t *testing.T) {
mwFuncCalls := 0
handler := &mockHandler{
serverHttp: func(w http.ResponseWriter, r *http.Request) {
if mwFuncCalls != 1 {
t.Fatalf("Expected Handler to be called after mw run. However, the middleware run {%d} times before the handler.", mwFuncCalls)
}
mwFuncCalls++
},
}
newHandler := NewMiddleware(func(http.ResponseWriter, *http.Request) {
mwFuncCalls++
}).Middleware(handler)
newHandler.ServeHTTP(&httptest.ResponseRecorder{}, &http.Request{})
if mwFuncCalls != 2 {
t.Fatalf("Expected mwFuncCalls to be 2, but got {%d}", mwFuncCalls)
}
}
type mockHandler struct {
serverHttp func(writer http.ResponseWriter, request *http.Request)
}
func (d *mockHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
d.serverHttp(writer, request)
}

Loading…
Cancel
Save