From 5dd56998c22c824ad2e13c50bc3213e85b125134 Mon Sep 17 00:00:00 2001 From: "Eric J. Holmes" Date: Sat, 4 Jun 2016 15:21:55 +0700 Subject: [PATCH 1/2] Add failing context.Context test for go1.7. --- context_test.go | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 context_test.go diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..c150edf --- /dev/null +++ b/context_test.go @@ -0,0 +1,32 @@ +// +build go1.7 + +package mux + +import ( + "context" + "net/http" + "testing" + "time" +) + +func TestNativeContextMiddleware(t *testing.T) { + withTimeout := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), time.Minute) + defer cancel() + h.ServeHTTP(w, r.WithContext(ctx)) + }) + } + + r := NewRouter() + r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + vars := Vars(r) + if vars["foo"] != "bar" { + t.Fatal("Expected foo var to be set") + } + }))) + + rec := NewRecorder() + req := newRequest("GET", "/path/bar") + r.ServeHTTP(rec, req) +} From fdfca9f9173962ed22bc63adc02a1589e43c1405 Mon Sep 17 00:00:00 2001 From: "Eric J. Holmes" Date: Sat, 4 Jun 2016 15:04:23 +0700 Subject: [PATCH 2/2] Support native context.Context when go1.7 is used. --- context_gorilla.go | 26 +++++++++++++++ context_gorilla_test.go | 40 +++++++++++++++++++++++ context_native.go | 24 ++++++++++++++ context_test.go => context_native_test.go | 0 mux.go | 28 +++++++--------- mux_test.go | 32 ------------------ 6 files changed, 102 insertions(+), 48 deletions(-) create mode 100644 context_gorilla.go create mode 100644 context_gorilla_test.go create mode 100644 context_native.go rename context_test.go => context_native_test.go (100%) diff --git a/context_gorilla.go b/context_gorilla.go new file mode 100644 index 0000000..d7adaa8 --- /dev/null +++ b/context_gorilla.go @@ -0,0 +1,26 @@ +// +build !go1.7 + +package mux + +import ( + "net/http" + + "github.com/gorilla/context" +) + +func contextGet(r *http.Request, key interface{}) interface{} { + return context.Get(r, key) +} + +func contextSet(r *http.Request, key, val interface{}) *http.Request { + if val == nil { + return r + } + + context.Set(r, key, val) + return r +} + +func contextClear(r *http.Request) { + context.Clear(r) +} diff --git a/context_gorilla_test.go b/context_gorilla_test.go new file mode 100644 index 0000000..ffaf384 --- /dev/null +++ b/context_gorilla_test.go @@ -0,0 +1,40 @@ +// +build !go1.7 + +package mux + +import ( + "net/http" + "testing" + + "github.com/gorilla/context" +) + +// Tests that the context is cleared or not cleared properly depending on +// the configuration of the router +func TestKeepContext(t *testing.T) { + func1 := func(w http.ResponseWriter, r *http.Request) {} + + r := NewRouter() + r.HandleFunc("/", func1).Name("func1") + + req, _ := http.NewRequest("GET", "http://localhost/", nil) + context.Set(req, "t", 1) + + res := new(http.ResponseWriter) + r.ServeHTTP(*res, req) + + if _, ok := context.GetOk(req, "t"); ok { + t.Error("Context should have been cleared at end of request") + } + + r.KeepContext = true + + req, _ = http.NewRequest("GET", "http://localhost/", nil) + context.Set(req, "t", 1) + + r.ServeHTTP(*res, req) + if _, ok := context.GetOk(req, "t"); !ok { + t.Error("Context should NOT have been cleared at end of request") + } + +} diff --git a/context_native.go b/context_native.go new file mode 100644 index 0000000..209cbea --- /dev/null +++ b/context_native.go @@ -0,0 +1,24 @@ +// +build go1.7 + +package mux + +import ( + "context" + "net/http" +) + +func contextGet(r *http.Request, key interface{}) interface{} { + return r.Context().Value(key) +} + +func contextSet(r *http.Request, key, val interface{}) *http.Request { + if val == nil { + return r + } + + return r.WithContext(context.WithValue(r.Context(), key, val)) +} + +func contextClear(r *http.Request) { + return +} diff --git a/context_test.go b/context_native_test.go similarity index 100% rename from context_test.go rename to context_native_test.go diff --git a/mux.go b/mux.go index 94f5ddd..f8c10f3 100644 --- a/mux.go +++ b/mux.go @@ -10,8 +10,6 @@ import ( "net/http" "path" "regexp" - - "github.com/gorilla/context" ) // NewRouter returns a new router instance. @@ -50,7 +48,9 @@ type Router struct { strictSlash bool // See Router.SkipClean(). This defines the flag for new routes. skipClean bool - // If true, do not clear the request context after handling the request + // If true, do not clear the request context after handling the request. + // This has no effect when go1.7+ is used, since the context is stored + // on the request itself. KeepContext bool } @@ -95,14 +95,14 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { var handler http.Handler if r.Match(req, &match) { handler = match.Handler - setVars(req, match.Vars) - setCurrentRoute(req, match.Route) + req = setVars(req, match.Vars) + req = setCurrentRoute(req, match.Route) } if handler == nil { handler = http.NotFoundHandler() } if !r.KeepContext { - defer context.Clear(req) + defer contextClear(req) } handler.ServeHTTP(w, req) } @@ -325,7 +325,7 @@ const ( // Vars returns the route variables for the current request, if any. func Vars(r *http.Request) map[string]string { - if rv := context.Get(r, varsKey); rv != nil { + if rv := contextGet(r, varsKey); rv != nil { return rv.(map[string]string) } return nil @@ -337,22 +337,18 @@ func Vars(r *http.Request) map[string]string { // after the handler returns, unless the KeepContext option is set on the // Router. func CurrentRoute(r *http.Request) *Route { - if rv := context.Get(r, routeKey); rv != nil { + if rv := contextGet(r, routeKey); rv != nil { return rv.(*Route) } return nil } -func setVars(r *http.Request, val interface{}) { - if val != nil { - context.Set(r, varsKey, val) - } +func setVars(r *http.Request, val interface{}) *http.Request { + return contextSet(r, varsKey, val) } -func setCurrentRoute(r *http.Request, val interface{}) { - if val != nil { - context.Set(r, routeKey, val) - } +func setCurrentRoute(r *http.Request, val interface{}) *http.Request { + return contextSet(r, routeKey, val) } // ---------------------------------------------------------------------------- diff --git a/mux_test.go b/mux_test.go index 777d063..98ac82d 100644 --- a/mux_test.go +++ b/mux_test.go @@ -9,8 +9,6 @@ import ( "net/http" "strings" "testing" - - "github.com/gorilla/context" ) func (r *Route) GoString() string { @@ -1316,36 +1314,6 @@ func testTemplate(t *testing.T, test routeTest) { } } -// Tests that the context is cleared or not cleared properly depending on -// the configuration of the router -func TestKeepContext(t *testing.T) { - func1 := func(w http.ResponseWriter, r *http.Request) {} - - r := NewRouter() - r.HandleFunc("/", func1).Name("func1") - - req, _ := http.NewRequest("GET", "http://localhost/", nil) - context.Set(req, "t", 1) - - res := new(http.ResponseWriter) - r.ServeHTTP(*res, req) - - if _, ok := context.GetOk(req, "t"); ok { - t.Error("Context should have been cleared at end of request") - } - - r.KeepContext = true - - req, _ = http.NewRequest("GET", "http://localhost/", nil) - context.Set(req, "t", 1) - - r.ServeHTTP(*res, req) - if _, ok := context.GetOk(req, "t"); !ok { - t.Error("Context should NOT have been cleared at end of request") - } - -} - type TestA301ResponseWriter struct { hh http.Header status int