Browse Source

Resolves conflicts(#515)

pull/517/head
RaviKiran Kilingar 5 years ago
parent
commit
0f80a441ee
No known key found for this signature in database
GPG Key ID: 47D29311AEB6F0A7
  1. 101
      .circleci/config.yml
  2. 18
      context.go
  3. 30
      context_test.go
  4. 15
      middleware.go
  5. 20
      middleware_test.go
  6. 25
      mux.go
  7. 49
      mux_httpserver_test.go
  8. 34
      mux_test.go
  9. 40
      old_test.go
  10. 57
      regexp.go
  11. 91
      regexp_test.go
  12. 36
      route.go
  13. 2
      test_helpers.go

101
.circleci/config.yml

@ -1,87 +1,70 @@
version: 2.0 version: 2.1
jobs: jobs:
# Base test configuration for Go library tests Each distinct version should "test":
# inherit this base, and override (at least) the container image used. parameters:
"test": &test version:
type: string
default: "latest"
golint:
type: boolean
default: true
modules:
type: boolean
default: true
goproxy:
type: string
default: ""
docker: docker:
- image: circleci/golang:latest - image: "circleci/golang:<< parameters.version >>"
working_directory: /go/src/github.com/gorilla/mux working_directory: /go/src/github.com/gorilla/mux
steps: &steps environment:
# Our build steps: we checkout the repo, fetch our deps, lint, and finally GO111MODULE: "on"
# run "go test" on the package. GOPROXY: "<< parameters.goproxy >>"
steps:
- checkout - checkout
# Logs the version in our build logs, for posterity - run:
- run: go version name: "Print the Go version"
command: >
go version
- run: - run:
name: "Fetch dependencies" name: "Fetch dependencies"
command: > command: >
go get -t -v ./... if [[ << parameters.modules >> = true ]]; then
go mod download
export GO111MODULE=on
else
go get -v ./...
fi
# Only run gofmt, vet & lint against the latest Go version # Only run gofmt, vet & lint against the latest Go version
- run: - run:
name: "Run golint" name: "Run golint"
command: > command: >
if [ "${LATEST}" = true ] && [ -z "${SKIP_GOLINT}" ]; then if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then
go get -u golang.org/x/lint/golint go get -u golang.org/x/lint/golint
golint ./... golint ./...
fi fi
- run: - run:
name: "Run gofmt" name: "Run gofmt"
command: > command: >
if [[ "${LATEST}" = true ]]; then if [[ << parameters.version >> = "latest" ]]; then
diff -u <(echo -n) <(gofmt -d -e .) diff -u <(echo -n) <(gofmt -d -e .)
fi fi
- run: - run:
name: "Run go vet" name: "Run go vet"
command: > command: >
if [[ "${LATEST}" = true ]]; then if [[ << parameters.version >> = "latest" ]]; then
go vet -v ./... go vet -v ./...
fi fi
- run: go test -v -race ./... - run:
name: "Run go test (+ race detector)"
"latest": command: >
<<: *test go test -v -race ./...
environment:
LATEST: true
"1.12":
<<: *test
docker:
- image: circleci/golang:1.12
"1.11":
<<: *test
docker:
- image: circleci/golang:1.11
"1.10":
<<: *test
docker:
- image: circleci/golang:1.10
"1.9":
<<: *test
docker:
- image: circleci/golang:1.9
"1.8":
<<: *test
docker:
- image: circleci/golang:1.8
"1.7":
<<: *test
docker:
- image: circleci/golang:1.7
workflows: workflows:
version: 2 tests:
build:
jobs: jobs:
- "latest" - test:
- "1.12" matrix:
- "1.11" parameters:
- "1.10" version: ["latest", "1.15", "1.14", "1.13", "1.12", "1.11"]
- "1.9"
- "1.8"
- "1.7"

18
context.go

@ -1,18 +0,0 @@
package mux
import (
"context"
"net/http"
)
func contextGet(r *http.Request, key interface{}) interface{} {
return r.Context().Value(key)
}
func contextSet(r *http.Request, key, val interface{}) *http.Request {
if val == nil {
return r
}
return r.WithContext(context.WithValue(r.Context(), key, val))
}

30
context_test.go

@ -1,30 +0,0 @@
package mux
import (
"context"
"net/http"
"testing"
"time"
)
func TestNativeContextMiddleware(t *testing.T) {
withTimeout := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
defer cancel()
h.ServeHTTP(w, r.WithContext(ctx))
})
}
r := NewRouter()
r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
vars := Vars(r)
if vars["foo"] != "bar" {
t.Fatal("Expected foo var to be set")
}
})))
rec := NewRecorder()
req := newRequest("GET", "/path/bar")
r.ServeHTTP(rec, req)
}

15
middleware.go

@ -58,22 +58,17 @@ func CORSMethodMiddleware(r *Router) MiddlewareFunc {
func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) { func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) {
var allMethods []string var allMethods []string
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error { for _, route := range r.routes {
for _, m := range route.matchers { var match RouteMatch
if _, ok := m.(*routeRegexp); ok { if route.Match(req, &match) || match.MatchErr == ErrMethodMismatch {
if m.Match(req, &RouteMatch{}) {
methods, err := route.GetMethods() methods, err := route.GetMethods()
if err != nil { if err != nil {
return err return nil, err
} }
allMethods = append(allMethods, methods...) allMethods = append(allMethods, methods...)
} }
break
} }
}
return nil
})
return allMethods, err return allMethods, nil
} }

20
middleware_test.go

@ -478,6 +478,26 @@ func TestCORSMethodMiddleware(t *testing.T) {
} }
} }
func TestCORSMethodMiddlewareSubrouter(t *testing.T) {
router := NewRouter().StrictSlash(true)
subrouter := router.PathPrefix("/test").Subrouter()
subrouter.HandleFunc("/hello", stringHandler("a")).Methods(http.MethodGet, http.MethodOptions, http.MethodPost)
subrouter.HandleFunc("/hello/{name}", stringHandler("b")).Methods(http.MethodGet, http.MethodOptions)
subrouter.Use(CORSMethodMiddleware(subrouter))
rw := NewRecorder()
req := newRequest("GET", "/test/hello/asdf")
router.ServeHTTP(rw, req)
actualMethods := rw.Header().Get("Access-Control-Allow-Methods")
expectedMethods := "GET,OPTIONS"
if actualMethods != expectedMethods {
t.Fatalf("expected methods %q but got: %q", expectedMethods, actualMethods)
}
}
func TestMiddlewareOnMultiSubrouter(t *testing.T) { func TestMiddlewareOnMultiSubrouter(t *testing.T) {
first := "first" first := "first"
second := "second" second := "second"

25
mux.go

@ -5,6 +5,7 @@
package mux package mux
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -58,8 +59,7 @@ type Router struct {
// If true, do not clear the request context after handling the request. // If true, do not clear the request context after handling the request.
// //
// Deprecated: No effect when go1.7+ is used, since the context is stored // Deprecated: No effect, since the context is stored on the request itself.
// on the request itself.
KeepContext bool KeepContext bool
// Slice of middlewares to be called after a match is found // Slice of middlewares to be called after a match is found
@ -195,8 +195,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var handler http.Handler var handler http.Handler
if r.Match(req, &match) { if r.Match(req, &match) {
handler = match.Handler handler = match.Handler
req = setVars(req, match.Vars) req = requestWithVars(req, match.Vars)
req = setCurrentRoute(req, match.Route) req = requestWithRoute(req, match.Route)
} }
if handler == nil && match.MatchErr == ErrMethodMismatch { if handler == nil && match.MatchErr == ErrMethodMismatch {
@ -426,7 +426,7 @@ const (
// Vars returns the route variables for the current request, if any. // Vars returns the route variables for the current request, if any.
func Vars(r *http.Request) map[string]string { func Vars(r *http.Request) map[string]string {
if rv := contextGet(r, varsKey); rv != nil { if rv := r.Context().Value(varsKey); rv != nil {
return rv.(map[string]string) return rv.(map[string]string)
} }
return nil return nil
@ -435,21 +435,22 @@ func Vars(r *http.Request) map[string]string {
// CurrentRoute returns the matched route for the current request, if any. // CurrentRoute returns the matched route for the current request, if any.
// This only works when called inside the handler of the matched route // This only works when called inside the handler of the matched route
// because the matched route is stored in the request context which is cleared // because the matched route is stored in the request context which is cleared
// after the handler returns, unless the KeepContext option is set on the // after the handler returns.
// Router.
func CurrentRoute(r *http.Request) *Route { func CurrentRoute(r *http.Request) *Route {
if rv := contextGet(r, routeKey); rv != nil { if rv := r.Context().Value(routeKey); rv != nil {
return rv.(*Route) return rv.(*Route)
} }
return nil return nil
} }
func setVars(r *http.Request, val interface{}) *http.Request { func requestWithVars(r *http.Request, vars map[string]string) *http.Request {
return contextSet(r, varsKey, val) ctx := context.WithValue(r.Context(), varsKey, vars)
return r.WithContext(ctx)
} }
func setCurrentRoute(r *http.Request, val interface{}) *http.Request { func requestWithRoute(r *http.Request, route *Route) *http.Request {
return contextSet(r, routeKey, val) ctx := context.WithValue(r.Context(), routeKey, route)
return r.WithContext(ctx)
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------

49
mux_httpserver_test.go

@ -0,0 +1,49 @@
// +build go1.9
package mux
import (
"bytes"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)
func TestSchemeMatchers(t *testing.T) {
router := NewRouter()
router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("hello http world"))
}).Schemes("http")
router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("hello https world"))
}).Schemes("https")
assertResponseBody := func(t *testing.T, s *httptest.Server, expectedBody string) {
resp, err := s.Client().Get(s.URL)
if err != nil {
t.Fatalf("unexpected error getting from server: %v", err)
}
if resp.StatusCode != 200 {
t.Fatalf("expected a status code of 200, got %v", resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unexpected error reading body: %v", err)
}
if !bytes.Equal(body, []byte(expectedBody)) {
t.Fatalf("response should be hello world, was: %q", string(body))
}
}
t.Run("httpServer", func(t *testing.T) {
s := httptest.NewServer(router)
defer s.Close()
assertResponseBody(t, s, "hello http world")
})
t.Run("httpsServer", func(t *testing.T) {
s := httptest.NewTLSServer(router)
defer s.Close()
assertResponseBody(t, s, "hello https world")
})
}

34
mux_test.go

@ -7,14 +7,17 @@ package mux
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"time"
) )
func (r *Route) GoString() string { func (r *Route) GoString() string {
@ -681,8 +684,8 @@ func TestHeaders(t *testing.T) {
}, },
{ {
title: "Headers route, regex header values to match", title: "Headers route, regex header values to match",
route: new(Route).Headers("foo", "ba[zr]"), route: new(Route).HeadersRegexp("foo", "ba[zr]"),
request: newRequestHeaders("GET", "http://localhost", map[string]string{"foo": "bar"}), request: newRequestHeaders("GET", "http://localhost", map[string]string{"foo": "baw"}),
vars: map[string]string{}, vars: map[string]string{},
host: "", host: "",
path: "", path: "",
@ -2803,6 +2806,28 @@ func TestSubrouterNotFound(t *testing.T) {
} }
} }
func TestContextMiddleware(t *testing.T) {
withTimeout := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
defer cancel()
h.ServeHTTP(w, r.WithContext(ctx))
})
}
r := NewRouter()
r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
vars := Vars(r)
if vars["foo"] != "bar" {
t.Fatal("Expected foo var to be set")
}
})))
rec := NewRecorder()
req := newRequest("GET", "/path/bar")
r.ServeHTTP(rec, req)
}
// testOptionsMiddleWare returns 200 on an OPTIONS request // testOptionsMiddleWare returns 200 on an OPTIONS request
func testOptionsMiddleWare(inner http.Handler) http.Handler { func testOptionsMiddleWare(inner http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -2937,10 +2962,7 @@ func newRequestWithHeaders(method, url string, headers ...string) *http.Request
// newRequestHost a new request with a method, url, and host header // newRequestHost a new request with a method, url, and host header
func newRequestHost(method, url, host string) *http.Request { func newRequestHost(method, url, host string) *http.Request {
req, err := http.NewRequest(method, url, nil) req := httptest.NewRequest(method, url, nil)
if err != nil {
panic(err)
}
req.Host = host req.Host = host
return req return req
} }

40
old_test.go

@ -260,6 +260,18 @@ var hostMatcherTests = []hostMatcherTest{
vars: map[string]string{"foo": "abc", "bar": "def", "baz": "ghi"}, vars: map[string]string{"foo": "abc", "bar": "def", "baz": "ghi"},
result: true, result: true,
}, },
{
matcher: NewRouter().NewRoute().Host("{foo:[a-z][a-z][a-z]}.{bar:[a-z][a-z][a-z]}.{baz:[a-z][a-z][a-z]}:{port:.*}"),
url: "http://abc.def.ghi:65535/",
vars: map[string]string{"foo": "abc", "bar": "def", "baz": "ghi", "port": "65535"},
result: true,
},
{
matcher: NewRouter().NewRoute().Host("{foo:[a-z][a-z][a-z]}.{bar:[a-z][a-z][a-z]}.{baz:[a-z][a-z][a-z]}"),
url: "http://abc.def.ghi:65535/",
vars: map[string]string{"foo": "abc", "bar": "def", "baz": "ghi"},
result: true,
},
{ {
matcher: NewRouter().NewRoute().Host("{foo:[a-z][a-z][a-z]}.{bar:[a-z][a-z][a-z]}.{baz:[a-z][a-z][a-z]}"), matcher: NewRouter().NewRoute().Host("{foo:[a-z][a-z][a-z]}.{bar:[a-z][a-z][a-z]}.{baz:[a-z][a-z][a-z]}"),
url: "http://a.b.c/", url: "http://a.b.c/",
@ -365,6 +377,11 @@ var urlBuildingTests = []urlBuildingTest{
vars: []string{"subdomain", "bar"}, vars: []string{"subdomain", "bar"},
url: "http://bar.domain.com", url: "http://bar.domain.com",
}, },
{
route: new(Route).Host("{subdomain}.domain.com:{port:.*}"),
vars: []string{"subdomain", "bar", "port", "65535"},
url: "http://bar.domain.com:65535",
},
{ {
route: new(Route).Host("foo.domain.com").Path("/articles"), route: new(Route).Host("foo.domain.com").Path("/articles"),
vars: []string{}, vars: []string{},
@ -385,6 +402,11 @@ var urlBuildingTests = []urlBuildingTest{
vars: []string{"subdomain", "foo", "category", "technology", "id", "42"}, vars: []string{"subdomain", "foo", "category", "technology", "id", "42"},
url: "http://foo.domain.com/articles/technology/42", url: "http://foo.domain.com/articles/technology/42",
}, },
{
route: new(Route).Host("example.com").Schemes("https", "http"),
vars: []string{},
url: "https://example.com",
},
} }
func TestHeaderMatcher(t *testing.T) { func TestHeaderMatcher(t *testing.T) {
@ -407,7 +429,11 @@ func TestHeaderMatcher(t *testing.T) {
func TestHostMatcher(t *testing.T) { func TestHostMatcher(t *testing.T) {
for _, v := range hostMatcherTests { for _, v := range hostMatcherTests {
request, _ := http.NewRequest("GET", v.url, nil) request, err := http.NewRequest("GET", v.url, nil)
if err != nil {
t.Errorf("http.NewRequest failed %#v", err)
continue
}
var routeMatch RouteMatch var routeMatch RouteMatch
result := v.matcher.Match(request, &routeMatch) result := v.matcher.Match(request, &routeMatch)
vars := routeMatch.Vars vars := routeMatch.Vars
@ -502,18 +528,6 @@ func TestUrlBuilding(t *testing.T) {
url := u.String() url := u.String()
if url != v.url { if url != v.url {
t.Errorf("expected %v, got %v", v.url, url) t.Errorf("expected %v, got %v", v.url, url)
/*
reversePath := ""
reverseHost := ""
if v.route.pathTemplate != nil {
reversePath = v.route.pathTemplate.Reverse
}
if v.route.hostTemplate != nil {
reverseHost = v.route.hostTemplate.Reverse
}
t.Errorf("%#v:\nexpected: %q\ngot: %q\nreverse path: %q\nreverse host: %q", v.route, v.url, url, reversePath, reverseHost)
*/
} }
} }

57
regexp.go

@ -181,7 +181,8 @@ func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
} }
} }
return r.regexp.MatchString(host) return r.regexp.MatchString(host)
} else { }
if r.regexpType == regexpTypeQuery { if r.regexpType == regexpTypeQuery {
return r.matchQueryString(req) return r.matchQueryString(req)
} }
@ -190,12 +191,11 @@ func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
path = req.URL.EscapedPath() path = req.URL.EscapedPath()
} }
return r.regexp.MatchString(path) return r.regexp.MatchString(path)
}
} }
// url builds a URL part using the given values. // url builds a URL part using the given values.
func (r *routeRegexp) url(values map[string]string) (string, error) { func (r *routeRegexp) url(values map[string]string) (string, error) {
urlValues := make([]interface{}, len(r.varsN)) urlValues := make([]interface{}, len(r.varsN), len(r.varsN))
for k, v := range r.varsN { for k, v := range r.varsN {
value, ok := values[v] value, ok := values[v]
if !ok { if !ok {
@ -230,14 +230,51 @@ func (r *routeRegexp) getURLQuery(req *http.Request) string {
return "" return ""
} }
templateKey := strings.SplitN(r.template, "=", 2)[0] templateKey := strings.SplitN(r.template, "=", 2)[0]
for key, vals := range req.URL.Query() { val, ok := findFirstQueryKey(req.URL.RawQuery, templateKey)
if key == templateKey && len(vals) > 0 { if ok {
return key + "=" + vals[0] return templateKey + "=" + val
}
} }
return "" return ""
} }
// findFirstQueryKey returns the same result as (*url.URL).Query()[key][0].
// If key was not found, empty string and false is returned.
func findFirstQueryKey(rawQuery, key string) (value string, ok bool) {
query := []byte(rawQuery)
for len(query) > 0 {
foundKey := query
if i := bytes.IndexAny(foundKey, "&;"); i >= 0 {
foundKey, query = foundKey[:i], foundKey[i+1:]
} else {
query = query[:0]
}
if len(foundKey) == 0 {
continue
}
var value []byte
if i := bytes.IndexByte(foundKey, '='); i >= 0 {
foundKey, value = foundKey[:i], foundKey[i+1:]
}
if len(foundKey) < len(key) {
// Cannot possibly be key.
continue
}
keyString, err := url.QueryUnescape(string(foundKey))
if err != nil {
continue
}
if keyString != key {
continue
}
valueString, err := url.QueryUnescape(string(value))
if err != nil {
continue
}
return valueString, true
}
return "", false
}
func (r *routeRegexp) matchQueryString(req *http.Request) bool { func (r *routeRegexp) matchQueryString(req *http.Request) bool {
return r.regexp.MatchString(r.getURLQuery(req)) return r.regexp.MatchString(r.getURLQuery(req))
} }
@ -288,6 +325,12 @@ func (v routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) {
// Store host variables. // Store host variables.
if v.host != nil { if v.host != nil {
host := getHost(req) host := getHost(req)
if v.host.wildcardHostPort {
// Don't be strict on the port match
if i := strings.Index(host, ":"); i != -1 {
host = host[:i]
}
}
matches := v.host.regexp.FindStringSubmatchIndex(host) matches := v.host.regexp.FindStringSubmatchIndex(host)
if len(matches) > 0 { if len(matches) > 0 {
extractVars(host, matches, v.host.varsN, m.Vars) extractVars(host, matches, v.host.varsN, m.Vars)

91
regexp_test.go

@ -0,0 +1,91 @@
package mux
import (
"net/url"
"reflect"
"strconv"
"testing"
)
func Test_findFirstQueryKey(t *testing.T) {
tests := []string{
"a=1&b=2",
"a=1&a=2&a=banana",
"ascii=%3Ckey%3A+0x90%3E",
"a=1;b=2",
"a=1&a=2;a=banana",
"a==",
"a=%2",
"a=20&%20%3F&=%23+%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09:%2F@$%27%28%29%2A%2C%3B&a=30",
"a=1& ?&=#+%!<>#\"{}|\\^[]`☺\t:/@$'()*,;&a=5",
"a=xxxxxxxxxxxxxxxx&b=YYYYYYYYYYYYYYY&c=ppppppppppppppppppp&f=ttttttttttttttttt&a=uuuuuuuuuuuuu",
}
for _, query := range tests {
t.Run(query, func(t *testing.T) {
// Check against url.ParseQuery, ignoring the error.
all, _ := url.ParseQuery(query)
for key, want := range all {
t.Run(key, func(t *testing.T) {
got, ok := findFirstQueryKey(query, key)
if !ok {
t.Error("Did not get expected key", key)
}
if !reflect.DeepEqual(got, want[0]) {
t.Errorf("findFirstQueryKey(%s,%s) = %v, want %v", query, key, got, want[0])
}
})
}
})
}
}
func Benchmark_findQueryKey(b *testing.B) {
tests := []string{
"a=1&b=2",
"ascii=%3Ckey%3A+0x90%3E",
"a=20&%20%3F&=%23+%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09:%2F@$%27%28%29%2A%2C%3B&a=30",
"a=xxxxxxxxxxxxxxxx&bbb=YYYYYYYYYYYYYYY&cccc=ppppppppppppppppppp&ddddd=ttttttttttttttttt&a=uuuuuuuuuuuuu",
"a=;b=;c=;d=;e=;f=;g=;h=;i=,j=;k=",
}
for i, query := range tests {
b.Run(strconv.Itoa(i), func(b *testing.B) {
// Check against url.ParseQuery, ignoring the error.
all, _ := url.ParseQuery(query)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
for key, _ := range all {
_, _ = findFirstQueryKey(query, key)
}
}
})
}
}
func Benchmark_findQueryKeyGoLib(b *testing.B) {
tests := []string{
"a=1&b=2",
"ascii=%3Ckey%3A+0x90%3E",
"a=20&%20%3F&=%23+%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09:%2F@$%27%28%29%2A%2C%3B&a=30",
"a=xxxxxxxxxxxxxxxx&bbb=YYYYYYYYYYYYYYY&cccc=ppppppppppppppppppp&ddddd=ttttttttttttttttt&a=uuuuuuuuuuuuu",
"a=;b=;c=;d=;e=;f=;g=;h=;i=,j=;k=",
}
for i, query := range tests {
b.Run(strconv.Itoa(i), func(b *testing.B) {
// Check against url.ParseQuery, ignoring the error.
all, _ := url.ParseQuery(query)
var u url.URL
u.RawQuery = query
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
for key, _ := range all {
v := u.Query()[key]
if len(v) > 0 {
_ = v[0]
}
}
}
})
}
}

36
route.go

@ -414,11 +414,30 @@ func (r *Route) Queries(pairs ...string) *Route {
type schemeMatcher []string type schemeMatcher []string
func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool { func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, r.URL.Scheme) scheme := r.URL.Scheme
// https://golang.org/pkg/net/http/#Request
// "For [most] server requests, fields other than Path and RawQuery will be
// empty."
// Since we're an http muxer, the scheme is either going to be http or https
// though, so we can just set it based on the tls termination state.
if scheme == "" {
if r.TLS == nil {
scheme = "http"
} else {
scheme = "https"
}
}
return matchInArray(m, scheme)
} }
// Schemes adds a matcher for URL schemes. // Schemes adds a matcher for URL schemes.
// It accepts a sequence of schemes to be matched, e.g.: "http", "https". // It accepts a sequence of schemes to be matched, e.g.: "http", "https".
// If the request's URL has a scheme set, it will be matched against.
// Generally, the URL scheme will only be set if a previous handler set it,
// such as the ProxyHeaders handler from gorilla/handlers.
// If unset, the scheme will be determined based on the request's TLS
// termination state.
// The first argument to Schemes will be used when constructing a route URL.
func (r *Route) Schemes(schemes ...string) *Route { func (r *Route) Schemes(schemes ...string) *Route {
for k, v := range schemes { for k, v := range schemes {
schemes[k] = strings.ToLower(v) schemes[k] = strings.ToLower(v)
@ -495,8 +514,8 @@ func (r *Route) Subrouter() *Router {
// This also works for host variables: // This also works for host variables:
// //
// r := mux.NewRouter() // r := mux.NewRouter()
// r.Host("{subdomain}.domain.com"). // r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
// HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). // Host("{subdomain}.domain.com").
// Name("article") // Name("article")
// //
// // url.String() will be "http://news.domain.com/articles/technology/42" // // url.String() will be "http://news.domain.com/articles/technology/42"
@ -504,6 +523,13 @@ func (r *Route) Subrouter() *Router {
// "category", "technology", // "category", "technology",
// "id", "42") // "id", "42")
// //
// The scheme of the resulting url will be the first argument that was passed to Schemes:
//
// // url.String() will be "https://example.com"
// r := mux.NewRouter()
// url, err := r.Host("example.com")
// .Schemes("https", "http").URL()
//
// All variables defined in the route are required, and their values must // All variables defined in the route are required, and their values must
// conform to the corresponding patterns. // conform to the corresponding patterns.
func (r *Route) URL(pairs ...string) (*url.URL, error) { func (r *Route) URL(pairs ...string) (*url.URL, error) {
@ -637,7 +663,7 @@ func (r *Route) GetQueriesRegexp() ([]string, error) {
if r.regexp.queries == nil { if r.regexp.queries == nil {
return nil, errors.New("mux: route doesn't have queries") return nil, errors.New("mux: route doesn't have queries")
} }
var queries []string queries := make([]string, 0, len(r.regexp.queries))
for _, query := range r.regexp.queries { for _, query := range r.regexp.queries {
queries = append(queries, query.regexp.String()) queries = append(queries, query.regexp.String())
} }
@ -656,7 +682,7 @@ func (r *Route) GetQueriesTemplates() ([]string, error) {
if r.regexp.queries == nil { if r.regexp.queries == nil {
return nil, errors.New("mux: route doesn't have queries") return nil, errors.New("mux: route doesn't have queries")
} }
var queries []string queries := make([]string, 0, len(r.regexp.queries))
for _, query := range r.regexp.queries { for _, query := range r.regexp.queries {
queries = append(queries, query.template) queries = append(queries, query.template)
} }

2
test_helpers.go

@ -15,5 +15,5 @@ import "net/http"
// can be set by making a route that captures the required variables, // can be set by making a route that captures the required variables,
// starting a server and sending the request to that server. // starting a server and sending the request to that server.
func SetURLVars(r *http.Request, val map[string]string) *http.Request { func SetURLVars(r *http.Request, val map[string]string) *http.Request {
return setVars(r, val) return requestWithVars(r, val)
} }

Loading…
Cancel
Save