Browse Source

Support native context.Context when go1.7 is used.

pull/169/head
Eric J. Holmes 10 years ago
parent
commit
fdfca9f917
  1. 26
      context_gorilla.go
  2. 40
      context_gorilla_test.go
  3. 24
      context_native.go
  4. 0
      context_native_test.go
  5. 28
      mux.go
  6. 32
      mux_test.go

26
context_gorilla.go

@ -0,0 +1,26 @@ @@ -0,0 +1,26 @@
// +build !go1.7
package mux
import (
"net/http"
"github.com/gorilla/context"
)
func contextGet(r *http.Request, key interface{}) interface{} {
return context.Get(r, key)
}
func contextSet(r *http.Request, key, val interface{}) *http.Request {
if val == nil {
return r
}
context.Set(r, key, val)
return r
}
func contextClear(r *http.Request) {
context.Clear(r)
}

40
context_gorilla_test.go

@ -0,0 +1,40 @@ @@ -0,0 +1,40 @@
// +build !go1.7
package mux
import (
"net/http"
"testing"
"github.com/gorilla/context"
)
// Tests that the context is cleared or not cleared properly depending on
// the configuration of the router
func TestKeepContext(t *testing.T) {
func1 := func(w http.ResponseWriter, r *http.Request) {}
r := NewRouter()
r.HandleFunc("/", func1).Name("func1")
req, _ := http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
res := new(http.ResponseWriter)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); ok {
t.Error("Context should have been cleared at end of request")
}
r.KeepContext = true
req, _ = http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); !ok {
t.Error("Context should NOT have been cleared at end of request")
}
}

24
context_native.go

@ -0,0 +1,24 @@ @@ -0,0 +1,24 @@
// +build go1.7
package mux
import (
"context"
"net/http"
)
func contextGet(r *http.Request, key interface{}) interface{} {
return r.Context().Value(key)
}
func contextSet(r *http.Request, key, val interface{}) *http.Request {
if val == nil {
return r
}
return r.WithContext(context.WithValue(r.Context(), key, val))
}
func contextClear(r *http.Request) {
return
}

0
context_test.go → context_native_test.go

28
mux.go

@ -10,8 +10,6 @@ import ( @@ -10,8 +10,6 @@ import (
"net/http"
"path"
"regexp"
"github.com/gorilla/context"
)
// NewRouter returns a new router instance.
@ -50,7 +48,9 @@ type Router struct { @@ -50,7 +48,9 @@ type Router struct {
strictSlash bool
// See Router.SkipClean(). This defines the flag for new routes.
skipClean bool
// If true, do not clear the request context after handling the request
// If true, do not clear the request context after handling the request.
// This has no effect when go1.7+ is used, since the context is stored
// on the request itself.
KeepContext bool
}
@ -95,14 +95,14 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -95,14 +95,14 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var handler http.Handler
if r.Match(req, &match) {
handler = match.Handler
setVars(req, match.Vars)
setCurrentRoute(req, match.Route)
req = setVars(req, match.Vars)
req = setCurrentRoute(req, match.Route)
}
if handler == nil {
handler = http.NotFoundHandler()
}
if !r.KeepContext {
defer context.Clear(req)
defer contextClear(req)
}
handler.ServeHTTP(w, req)
}
@ -325,7 +325,7 @@ const ( @@ -325,7 +325,7 @@ const (
// Vars returns the route variables for the current request, if any.
func Vars(r *http.Request) map[string]string {
if rv := context.Get(r, varsKey); rv != nil {
if rv := contextGet(r, varsKey); rv != nil {
return rv.(map[string]string)
}
return nil
@ -337,22 +337,18 @@ func Vars(r *http.Request) map[string]string { @@ -337,22 +337,18 @@ func Vars(r *http.Request) map[string]string {
// after the handler returns, unless the KeepContext option is set on the
// Router.
func CurrentRoute(r *http.Request) *Route {
if rv := context.Get(r, routeKey); rv != nil {
if rv := contextGet(r, routeKey); rv != nil {
return rv.(*Route)
}
return nil
}
func setVars(r *http.Request, val interface{}) {
if val != nil {
context.Set(r, varsKey, val)
}
func setVars(r *http.Request, val interface{}) *http.Request {
return contextSet(r, varsKey, val)
}
func setCurrentRoute(r *http.Request, val interface{}) {
if val != nil {
context.Set(r, routeKey, val)
}
func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
return contextSet(r, routeKey, val)
}
// ----------------------------------------------------------------------------

32
mux_test.go

@ -9,8 +9,6 @@ import ( @@ -9,8 +9,6 @@ import (
"net/http"
"strings"
"testing"
"github.com/gorilla/context"
)
func (r *Route) GoString() string {
@ -1316,36 +1314,6 @@ func testTemplate(t *testing.T, test routeTest) { @@ -1316,36 +1314,6 @@ func testTemplate(t *testing.T, test routeTest) {
}
}
// Tests that the context is cleared or not cleared properly depending on
// the configuration of the router
func TestKeepContext(t *testing.T) {
func1 := func(w http.ResponseWriter, r *http.Request) {}
r := NewRouter()
r.HandleFunc("/", func1).Name("func1")
req, _ := http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
res := new(http.ResponseWriter)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); ok {
t.Error("Context should have been cleared at end of request")
}
r.KeepContext = true
req, _ = http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); !ok {
t.Error("Context should NOT have been cleared at end of request")
}
}
type TestA301ResponseWriter struct {
hh http.Header
status int

Loading…
Cancel
Save