Browse Source

Merge 21f221575c into 3cf0d013e5

pull/647/merge
Mariano 4 years ago committed by GitHub
parent
commit
c7bd8eb190
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 22
      middleware.go
  2. 34
      middleware_test.go
  3. 6
      mux.go
  4. 2
      mux_httpserver_test.go
  5. 2
      mux_test.go
  6. 4
      regexp_test.go

22
middleware.go

@ -20,6 +20,28 @@ func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler { @@ -20,6 +20,28 @@ func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler {
return mw(handler)
}
// NewMiddleware helper method which allows creating a middleware from a function which receives the response and request.
// The typical case would be passing a closure
func NewMiddleware(mw func(http.ResponseWriter, *http.Request)) MiddlewareFunc {
return func(handler http.Handler) http.Handler {
return &genericMW{
process: func(w http.ResponseWriter, r *http.Request) {
mw(w, r)
handler.ServeHTTP(w, r)
},
}
}
}
// genericMW is for creating a middleware from a closure, it's used in the public method NewMiddleware
type genericMW struct {
process func(http.ResponseWriter, *http.Request)
}
func (instance *genericMW) ServeHTTP(w http.ResponseWriter, req *http.Request) {
instance.process(w, req)
}
// Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
func (r *Router) Use(mwf ...MiddlewareFunc) {
for _, fn := range mwf {

34
middleware_test.go

@ -3,6 +3,7 @@ package mux @@ -3,6 +3,7 @@ package mux
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
)
@ -563,3 +564,36 @@ func TestMiddlewareOnMultiSubrouter(t *testing.T) { @@ -563,3 +564,36 @@ func TestMiddlewareOnMultiSubrouter(t *testing.T) {
}
})
}
func TestNewMiddleware(t *testing.T) {
mwFuncCalls := 0
handler := &mockHandler{
serverHTTP: func(w http.ResponseWriter, r *http.Request) {
if mwFuncCalls != 1 {
t.Fatalf("Expected Handler to be called after mw run. However, the middleware run {%d} times before the handler.", mwFuncCalls)
}
mwFuncCalls++
},
}
newHandler := NewMiddleware(func(http.ResponseWriter, *http.Request) {
mwFuncCalls++
}).Middleware(handler)
newHandler.ServeHTTP(&httptest.ResponseRecorder{}, &http.Request{})
if mwFuncCalls != 2 {
t.Fatalf("Expected mwFuncCalls to be 2, but got {%d}", mwFuncCalls)
}
}
type mockHandler struct {
serverHTTP func(writer http.ResponseWriter, request *http.Request)
}
func (d *mockHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
d.serverHTTP(writer, request)
}

6
mux.go

@ -363,9 +363,9 @@ func (r *Router) Walk(walkFn WalkFunc) error { @@ -363,9 +363,9 @@ func (r *Router) Walk(walkFn WalkFunc) error {
return r.walk(walkFn, []*Route{})
}
// SkipRouter is used as a return value from WalkFuncs to indicate that the
// ErrSkipRouter is used as a return value from WalkFuncs to indicate that the
// router that walk is about to descend down to should be skipped.
var SkipRouter = errors.New("skip this router")
var ErrSkipRouter = errors.New("skip this router")
// WalkFunc is the type of the function called for each route visited by Walk.
// At every invocation, it is given the current route, and the current router,
@ -375,7 +375,7 @@ type WalkFunc func(route *Route, router *Router, ancestors []*Route) error @@ -375,7 +375,7 @@ type WalkFunc func(route *Route, router *Router, ancestors []*Route) error
func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error {
for _, t := range r.routes {
err := walkFn(t, r, ancestors)
if err == SkipRouter {
if err == ErrSkipRouter {
continue
}
if err != nil {

2
mux_httpserver_test.go

@ -1,5 +1,3 @@ @@ -1,5 +1,3 @@
// +build go1.9
package mux
import (

2
mux_test.go

@ -1600,7 +1600,7 @@ func TestWalkSingleDepth(t *testing.T) { @@ -1600,7 +1600,7 @@ func TestWalkSingleDepth(t *testing.T) {
err := r0.Walk(func(route *Route, router *Router, ancestors []*Route) error {
matcher := route.matchers[0].(*routeRegexp)
if matcher.template == "/d" {
return SkipRouter
return ErrSkipRouter
}
if len(ancestors) != depths[i] {
t.Errorf(`Expected depth of %d at i = %d; got "%d"`, depths[i], i, len(ancestors))

4
regexp_test.go

@ -54,7 +54,7 @@ func Benchmark_findQueryKey(b *testing.B) { @@ -54,7 +54,7 @@ func Benchmark_findQueryKey(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
for key, _ := range all {
for key := range all {
_, _ = findFirstQueryKey(query, key)
}
}
@ -79,7 +79,7 @@ func Benchmark_findQueryKeyGoLib(b *testing.B) { @@ -79,7 +79,7 @@ func Benchmark_findQueryKeyGoLib(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
for key, _ := range all {
for key := range all {
v := u.Query()[key]
if len(v) > 0 {
_ = v[0]

Loading…
Cancel
Save