diff --git a/mux.go b/mux.go index 73c0996..fe31f09 100644 --- a/mux.go +++ b/mux.go @@ -308,10 +308,12 @@ func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error { } for _, sr := range t.matchers { if h, ok := sr.(*Router); ok { + ancestors = append(ancestors, t) err := h.walk(walkFn, ancestors) if err != nil { return err } + ancestors = ancestors[:len(ancestors)-1] } } if h, ok := t.handler.(*Router); ok { diff --git a/mux_test.go b/mux_test.go index 6d3bdd2..dd7836b 100644 --- a/mux_test.go +++ b/mux_test.go @@ -11,6 +11,7 @@ import ( "fmt" "net/http" "net/url" + "reflect" "strings" "testing" ) @@ -1382,22 +1383,38 @@ func TestWalkNested(t *testing.T) { l2 := l1.PathPrefix("/l").Subrouter() l2.Path("/a") - paths := []string{"/g", "/g/o", "/g/o/r", "/g/o/r/i", "/g/o/r/i/l", "/g/o/r/i/l/l", "/g/o/r/i/l/l/a"} + testCases := []struct { + path string + ancestors []*Route + }{ + {"/g", []*Route{}}, + {"/g/o", []*Route{g.parent.(*Route)}}, + {"/g/o/r", []*Route{g.parent.(*Route), o.parent.(*Route)}}, + {"/g/o/r/i", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route)}}, + {"/g/o/r/i/l", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route)}}, + {"/g/o/r/i/l/l", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route), l1.parent.(*Route)}}, + {"/g/o/r/i/l/l/a", []*Route{g.parent.(*Route), o.parent.(*Route), r.parent.(*Route), i.parent.(*Route), l1.parent.(*Route), l2.parent.(*Route)}}, + } + idx := 0 err := router.Walk(func(route *Route, router *Router, ancestors []*Route) error { - path := paths[idx] + path := testCases[idx].path tpl := route.regexp.path.template if tpl != path { t.Errorf(`Expected %s got %s`, path, tpl) } + currWantAncestors := testCases[idx].ancestors + if !reflect.DeepEqual(currWantAncestors, ancestors) { + t.Errorf(`Expected %+v got %+v`, currWantAncestors, ancestors) + } idx++ return nil }) if err != nil { panic(err) } - if idx != len(paths) { - t.Errorf("Expected %d routes, found %d", len(paths), idx) + if idx != len(testCases) { + t.Errorf("Expected %d routes, found %d", len(testCases), idx) } }