Browse Source

Merge pull request #169 from ejholmes/go1.7-context

Store vars and route in context.Context when go1.7+ is used
pull/170/head
Kamil Kisiel 10 years ago
parent
commit
e84fac997f
  1. 26
      context_gorilla.go
  2. 40
      context_gorilla_test.go
  3. 24
      context_native.go
  4. 32
      context_native_test.go
  5. 28
      mux.go
  6. 32
      mux_test.go

26
context_gorilla.go

@ -0,0 +1,26 @@
// +build !go1.7
package mux
import (
"net/http"
"github.com/gorilla/context"
)
func contextGet(r *http.Request, key interface{}) interface{} {
return context.Get(r, key)
}
func contextSet(r *http.Request, key, val interface{}) *http.Request {
if val == nil {
return r
}
context.Set(r, key, val)
return r
}
func contextClear(r *http.Request) {
context.Clear(r)
}

40
context_gorilla_test.go

@ -0,0 +1,40 @@
// +build !go1.7
package mux
import (
"net/http"
"testing"
"github.com/gorilla/context"
)
// Tests that the context is cleared or not cleared properly depending on
// the configuration of the router
func TestKeepContext(t *testing.T) {
func1 := func(w http.ResponseWriter, r *http.Request) {}
r := NewRouter()
r.HandleFunc("/", func1).Name("func1")
req, _ := http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
res := new(http.ResponseWriter)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); ok {
t.Error("Context should have been cleared at end of request")
}
r.KeepContext = true
req, _ = http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); !ok {
t.Error("Context should NOT have been cleared at end of request")
}
}

24
context_native.go

@ -0,0 +1,24 @@
// +build go1.7
package mux
import (
"context"
"net/http"
)
func contextGet(r *http.Request, key interface{}) interface{} {
return r.Context().Value(key)
}
func contextSet(r *http.Request, key, val interface{}) *http.Request {
if val == nil {
return r
}
return r.WithContext(context.WithValue(r.Context(), key, val))
}
func contextClear(r *http.Request) {
return
}

32
context_native_test.go

@ -0,0 +1,32 @@
// +build go1.7
package mux
import (
"context"
"net/http"
"testing"
"time"
)
func TestNativeContextMiddleware(t *testing.T) {
withTimeout := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
defer cancel()
h.ServeHTTP(w, r.WithContext(ctx))
})
}
r := NewRouter()
r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
vars := Vars(r)
if vars["foo"] != "bar" {
t.Fatal("Expected foo var to be set")
}
})))
rec := NewRecorder()
req := newRequest("GET", "/path/bar")
r.ServeHTTP(rec, req)
}

28
mux.go

@ -10,8 +10,6 @@ import (
"net/http" "net/http"
"path" "path"
"regexp" "regexp"
"github.com/gorilla/context"
) )
// NewRouter returns a new router instance. // NewRouter returns a new router instance.
@ -50,7 +48,9 @@ type Router struct {
strictSlash bool strictSlash bool
// See Router.SkipClean(). This defines the flag for new routes. // See Router.SkipClean(). This defines the flag for new routes.
skipClean bool skipClean bool
// If true, do not clear the request context after handling the request // If true, do not clear the request context after handling the request.
// This has no effect when go1.7+ is used, since the context is stored
// on the request itself.
KeepContext bool KeepContext bool
} }
@ -95,14 +95,14 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var handler http.Handler var handler http.Handler
if r.Match(req, &match) { if r.Match(req, &match) {
handler = match.Handler handler = match.Handler
setVars(req, match.Vars) req = setVars(req, match.Vars)
setCurrentRoute(req, match.Route) req = setCurrentRoute(req, match.Route)
} }
if handler == nil { if handler == nil {
handler = http.NotFoundHandler() handler = http.NotFoundHandler()
} }
if !r.KeepContext { if !r.KeepContext {
defer context.Clear(req) defer contextClear(req)
} }
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
} }
@ -325,7 +325,7 @@ const (
// Vars returns the route variables for the current request, if any. // Vars returns the route variables for the current request, if any.
func Vars(r *http.Request) map[string]string { func Vars(r *http.Request) map[string]string {
if rv := context.Get(r, varsKey); rv != nil { if rv := contextGet(r, varsKey); rv != nil {
return rv.(map[string]string) return rv.(map[string]string)
} }
return nil return nil
@ -337,22 +337,18 @@ func Vars(r *http.Request) map[string]string {
// after the handler returns, unless the KeepContext option is set on the // after the handler returns, unless the KeepContext option is set on the
// Router. // Router.
func CurrentRoute(r *http.Request) *Route { func CurrentRoute(r *http.Request) *Route {
if rv := context.Get(r, routeKey); rv != nil { if rv := contextGet(r, routeKey); rv != nil {
return rv.(*Route) return rv.(*Route)
} }
return nil return nil
} }
func setVars(r *http.Request, val interface{}) { func setVars(r *http.Request, val interface{}) *http.Request {
if val != nil { return contextSet(r, varsKey, val)
context.Set(r, varsKey, val)
}
} }
func setCurrentRoute(r *http.Request, val interface{}) { func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
if val != nil { return contextSet(r, routeKey, val)
context.Set(r, routeKey, val)
}
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------

32
mux_test.go

@ -9,8 +9,6 @@ import (
"net/http" "net/http"
"strings" "strings"
"testing" "testing"
"github.com/gorilla/context"
) )
func (r *Route) GoString() string { func (r *Route) GoString() string {
@ -1316,36 +1314,6 @@ func testTemplate(t *testing.T, test routeTest) {
} }
} }
// Tests that the context is cleared or not cleared properly depending on
// the configuration of the router
func TestKeepContext(t *testing.T) {
func1 := func(w http.ResponseWriter, r *http.Request) {}
r := NewRouter()
r.HandleFunc("/", func1).Name("func1")
req, _ := http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
res := new(http.ResponseWriter)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); ok {
t.Error("Context should have been cleared at end of request")
}
r.KeepContext = true
req, _ = http.NewRequest("GET", "http://localhost/", nil)
context.Set(req, "t", 1)
r.ServeHTTP(*res, req)
if _, ok := context.GetOk(req, "t"); !ok {
t.Error("Context should NOT have been cleared at end of request")
}
}
type TestA301ResponseWriter struct { type TestA301ResponseWriter struct {
hh http.Header hh http.Header
status int status int

Loading…
Cancel
Save