diff --git a/mux_test.go b/mux_test.go index 48506bf..e455bce 100644 --- a/mux_test.go +++ b/mux_test.go @@ -462,6 +462,15 @@ func TestQueries(t *testing.T) { path: "", shouldMatch: true, }, + { + title: "Queries route, match with a query string out of order", + route: new(Route).Host("www.example.com").Path("/api").Queries("foo", "bar", "baz", "ding"), + request: newRequest("GET", "http://www.example.com/api?baz=ding&foo=bar"), + vars: map[string]string{}, + host: "", + path: "", + shouldMatch: true, + }, { title: "Queries route, bad query", route: new(Route).Queries("foo", "bar", "baz", "ding"), diff --git a/regexp.go b/regexp.go index f1d3147..19a358b 100644 --- a/regexp.go +++ b/regexp.go @@ -34,7 +34,7 @@ func newRouteRegexp(tpl string, matchHost, matchPrefix, matchQuery, strictSlash // Now let's parse it. defaultPattern := "[^/]+" if matchQuery { - defaultPattern = "[^?]+" + defaultPattern = "[^?&]+" matchPrefix, strictSlash = true, false } else if matchHost { defaultPattern = "[^.]+" @@ -51,9 +51,9 @@ func newRouteRegexp(tpl string, matchHost, matchPrefix, matchQuery, strictSlash } varsN := make([]string, len(idxs)/2) varsR := make([]*regexp.Regexp, len(idxs)/2) - pattern := bytes.NewBufferString("^") - if matchQuery { - pattern = bytes.NewBufferString("") + pattern := bytes.NewBufferString("") + if !matchQuery { + pattern.WriteByte('^') } reverse := bytes.NewBufferString("") var end int @@ -209,9 +209,9 @@ func braceIndices(s string) ([]int, error) { // routeRegexpGroup groups the route matchers that carry variables. type routeRegexpGroup struct { - host *routeRegexp - path *routeRegexp - query *routeRegexp + host *routeRegexp + path *routeRegexp + queries []*routeRegexp } // setMatch extracts the variables from the URL once a route matches. @@ -249,11 +249,14 @@ func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) } } // Store query string variables. - if v.query != nil { - queryVars := v.query.regexp.FindStringSubmatch(req.URL.RawQuery) - if queryVars != nil { - for k, v := range v.query.varsN { - m.Vars[v] = queryVars[k+1] + if v.queries != nil && len(v.queries) > 0 { + rawQuery := req.URL.RawQuery + for _, q := range v.queries { + queryVars := q.regexp.FindStringSubmatch(rawQuery) + if queryVars != nil { + for k, v := range q.varsN { + m.Vars[v] = queryVars[k+1] + } } } } diff --git a/route.go b/route.go index 00989bf..1ac2065 100644 --- a/route.go +++ b/route.go @@ -5,7 +5,6 @@ package mux import ( - "bytes" "errors" "fmt" "net/http" @@ -153,14 +152,16 @@ func (r *Route) addRegexpMatcher(tpl string, matchHost, matchPrefix, matchQuery if err != nil { return err } - if matchHost { - if r.regexp.path != nil { - if err = uniqueVars(rr.varsN, r.regexp.path.varsN); err != nil { + if r.regexp.queries != nil { + for _, q := range r.regexp.queries { + if err = uniqueVars(rr.varsN, q.varsN); err != nil { return err } } - if r.regexp.query != nil { - if err = uniqueVars(rr.varsN, r.regexp.query.varsN); err != nil { + } + if matchHost { + if r.regexp.path != nil { + if err = uniqueVars(rr.varsN, r.regexp.path.varsN); err != nil { return err } } @@ -172,18 +173,13 @@ func (r *Route) addRegexpMatcher(tpl string, matchHost, matchPrefix, matchQuery } } if matchQuery { - if r.regexp.path != nil { - if err = uniqueVars(rr.varsN, r.regexp.path.varsN); err != nil { - return err - } + if r.regexp.queries == nil { + r.regexp.queries = make([]*routeRegexp, 1) + r.regexp.queries[0] = rr + } else { + r.regexp.queries = append(r.regexp.queries, rr) } - r.regexp.query = rr } else { - if r.regexp.query != nil { - if err = uniqueVars(rr.varsN, r.regexp.query.varsN); err != nil { - return err - } - } r.regexp.path = rr } } @@ -345,12 +341,11 @@ func (r *Route) Queries(pairs ...string) *Route { "mux: number of parameters must be multiple of 2, got %v", pairs) return nil } - var buf bytes.Buffer for i := 0; i < length; i += 2 { - buf.WriteString(fmt.Sprintf("%s=%s&", pairs[i], pairs[i+1])) + if r.err = r.addRegexpMatcher(fmt.Sprintf("%s=%s", pairs[i], pairs[i+1]), false, true, true); r.err != nil { + return r + } } - tpl := strings.TrimRight(buf.String(), "&") - r.err = r.addRegexpMatcher(tpl, false, true, true) return r } @@ -527,9 +522,9 @@ func (r *Route) getRegexpGroup() *routeRegexpGroup { } else { // Copy. r.regexp = &routeRegexpGroup{ - host: regexp.host, - path: regexp.path, - query: regexp.query, + host: regexp.host, + path: regexp.path, + queries: regexp.queries, } } }