diff --git a/middleware_test.go b/middleware_test.go index 24016cb..30df274 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -2,6 +2,7 @@ package mux import ( "bytes" + "fmt" "net/http" "net/http/httptest" "testing" @@ -28,12 +29,12 @@ func TestMiddlewareAdd(t *testing.T) { router.useInterface(mw) if len(router.middlewares) != 1 || router.middlewares[0] != mw { - t.Fatal("Middleware was not added correctly") + t.Fatal("Middleware interface was not added correctly") } router.Use(mw.Middleware) if len(router.middlewares) != 2 { - t.Fatal("MiddlewareFunc method was not added correctly") + t.Fatal("Middleware method was not added correctly") } banalMw := func(handler http.Handler) http.Handler { @@ -41,7 +42,7 @@ func TestMiddlewareAdd(t *testing.T) { } router.Use(banalMw) if len(router.middlewares) != 3 { - t.Fatal("MiddlewareFunc method was not added correctly") + t.Fatal("Middleware function was not added correctly") } } @@ -55,34 +56,37 @@ func TestMiddleware(t *testing.T) { rw := NewRecorder() req := newRequest("GET", "/") - // Test regular middleware call - router.ServeHTTP(rw, req) - if mw.timesCalled != 1 { - t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) - } - - // Middleware should not be called for 404 - req = newRequest("GET", "/not/found") - router.ServeHTTP(rw, req) - if mw.timesCalled != 1 { - t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) - } + t.Run("regular middleware call", func(t *testing.T) { + router.ServeHTTP(rw, req) + if mw.timesCalled != 1 { + t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) + } + }) - // Middleware should not be called if there is a method mismatch - req = newRequest("POST", "/") - router.ServeHTTP(rw, req) - if mw.timesCalled != 1 { - t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) - } + t.Run("not called for 404", func(t *testing.T) { + req = newRequest("GET", "/not/found") + router.ServeHTTP(rw, req) + if mw.timesCalled != 1 { + t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) + } + }) - // Add the middleware again as function - router.Use(mw.Middleware) - req = newRequest("GET", "/") - router.ServeHTTP(rw, req) - if mw.timesCalled != 3 { - t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled) - } + t.Run("not called for method mismatch", func(t *testing.T) { + req = newRequest("POST", "/") + router.ServeHTTP(rw, req) + if mw.timesCalled != 1 { + t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) + } + }) + t.Run("regular call using function middleware", func(t *testing.T) { + router.Use(mw.Middleware) + req = newRequest("GET", "/") + router.ServeHTTP(rw, req) + if mw.timesCalled != 3 { + t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled) + } + }) } func TestMiddlewareSubrouter(t *testing.T) { @@ -98,42 +102,56 @@ func TestMiddlewareSubrouter(t *testing.T) { rw := NewRecorder() req := newRequest("GET", "/") - router.ServeHTTP(rw, req) - if mw.timesCalled != 0 { - t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled) - } + t.Run("not called for route outside subrouter", func(t *testing.T) { + router.ServeHTTP(rw, req) + if mw.timesCalled != 0 { + t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled) + } + }) - req = newRequest("GET", "/sub/") - router.ServeHTTP(rw, req) - if mw.timesCalled != 0 { - t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled) - } + t.Run("not called for subrouter root 404", func(t *testing.T) { + req = newRequest("GET", "/sub/") + router.ServeHTTP(rw, req) + if mw.timesCalled != 0 { + t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled) + } + }) - req = newRequest("GET", "/sub/x") - router.ServeHTTP(rw, req) - if mw.timesCalled != 1 { - t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) - } + t.Run("called once for route inside subrouter", func(t *testing.T) { + req = newRequest("GET", "/sub/x") + router.ServeHTTP(rw, req) + if mw.timesCalled != 1 { + t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) + } + }) - req = newRequest("GET", "/sub/not/found") - router.ServeHTTP(rw, req) - if mw.timesCalled != 1 { - t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) - } + t.Run("not called for 404 inside subrouter", func(t *testing.T) { + req = newRequest("GET", "/sub/not/found") + router.ServeHTTP(rw, req) + if mw.timesCalled != 1 { + t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) + } + }) - router.useInterface(mw) + t.Run("middleware added to router", func(t *testing.T) { + router.useInterface(mw) - req = newRequest("GET", "/") - router.ServeHTTP(rw, req) - if mw.timesCalled != 2 { - t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled) - } + t.Run("called once for route outside subrouter", func(t *testing.T) { + req = newRequest("GET", "/") + router.ServeHTTP(rw, req) + if mw.timesCalled != 2 { + t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled) + } + }) - req = newRequest("GET", "/sub/x") - router.ServeHTTP(rw, req) - if mw.timesCalled != 4 { - t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled) - } + t.Run("called twice for route inside subrouter", func(t *testing.T) { + req = newRequest("GET", "/sub/x") + router.ServeHTTP(rw, req) + if mw.timesCalled != 4 { + t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled) + } + }) + }) } func TestMiddlewareExecution(t *testing.T) { @@ -145,30 +163,33 @@ func TestMiddlewareExecution(t *testing.T) { w.Write(handlerStr) }) - rw := NewRecorder() - req := newRequest("GET", "/") + t.Run("responds normally without middleware", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/") - // Test handler-only call - router.ServeHTTP(rw, req) + router.ServeHTTP(rw, req) - if !bytes.Equal(rw.Body.Bytes(), handlerStr) { - t.Fatal("Handler response is not what it should be") - } + if !bytes.Equal(rw.Body.Bytes(), handlerStr) { + t.Fatal("Handler response is not what it should be") + } + }) - // Test middleware call - rw = NewRecorder() + t.Run("responds with handler and middleware response", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/") - router.Use(func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write(mwStr) - h.ServeHTTP(w, r) + router.Use(func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(mwStr) + h.ServeHTTP(w, r) + }) }) - }) - router.ServeHTTP(rw, req) - if !bytes.Equal(rw.Body.Bytes(), append(mwStr, handlerStr...)) { - t.Fatal("Middleware + handler response is not what it should be") - } + router.ServeHTTP(rw, req) + if !bytes.Equal(rw.Body.Bytes(), append(mwStr, handlerStr...)) { + t.Fatal("Middleware + handler response is not what it should be") + } + }) } func TestMiddlewareNotFound(t *testing.T) { @@ -187,26 +208,29 @@ func TestMiddlewareNotFound(t *testing.T) { }) // Test not found call with default handler - rw := NewRecorder() - req := newRequest("GET", "/notfound") + t.Run("not called", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/notfound") - router.ServeHTTP(rw, req) - if bytes.Contains(rw.Body.Bytes(), mwStr) { - t.Fatal("Middleware was called for a 404") - } + router.ServeHTTP(rw, req) + if bytes.Contains(rw.Body.Bytes(), mwStr) { + t.Fatal("Middleware was called for a 404") + } + }) - // Test not found call with custom handler - rw = NewRecorder() - req = newRequest("GET", "/notfound") + t.Run("not called with custom not found handler", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/notfound") - router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.Write([]byte("Custom 404 handler")) - }) - router.ServeHTTP(rw, req) + router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Write([]byte("Custom 404 handler")) + }) + router.ServeHTTP(rw, req) - if bytes.Contains(rw.Body.Bytes(), mwStr) { - t.Fatal("Middleware was called for a custom 404") - } + if bytes.Contains(rw.Body.Bytes(), mwStr) { + t.Fatal("Middleware was called for a custom 404") + } + }) } func TestMiddlewareMethodMismatch(t *testing.T) { @@ -225,27 +249,29 @@ func TestMiddlewareMethodMismatch(t *testing.T) { }) }) - // Test method mismatch - rw := NewRecorder() - req := newRequest("POST", "/") + t.Run("not called", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("POST", "/") - router.ServeHTTP(rw, req) - if bytes.Contains(rw.Body.Bytes(), mwStr) { - t.Fatal("Middleware was called for a method mismatch") - } + router.ServeHTTP(rw, req) + if bytes.Contains(rw.Body.Bytes(), mwStr) { + t.Fatal("Middleware was called for a method mismatch") + } + }) - // Test not found call - rw = NewRecorder() - req = newRequest("POST", "/") + t.Run("not called with custom method not allowed handler", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("POST", "/") - router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.Write([]byte("Method not allowed")) - }) - router.ServeHTTP(rw, req) + router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Write([]byte("Method not allowed")) + }) + router.ServeHTTP(rw, req) - if bytes.Contains(rw.Body.Bytes(), mwStr) { - t.Fatal("Middleware was called for a method mismatch") - } + if bytes.Contains(rw.Body.Bytes(), mwStr) { + t.Fatal("Middleware was called for a method mismatch") + } + }) } func TestMiddlewareNotFoundSubrouter(t *testing.T) { @@ -269,27 +295,29 @@ func TestMiddlewareNotFoundSubrouter(t *testing.T) { }) }) - // Test not found call for default handler - rw := NewRecorder() - req := newRequest("GET", "/sub/notfound") + t.Run("not called", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/sub/notfound") - router.ServeHTTP(rw, req) - if bytes.Contains(rw.Body.Bytes(), mwStr) { - t.Fatal("Middleware was called for a 404") - } + router.ServeHTTP(rw, req) + if bytes.Contains(rw.Body.Bytes(), mwStr) { + t.Fatal("Middleware was called for a 404") + } + }) - // Test not found call with custom handler - rw = NewRecorder() - req = newRequest("GET", "/sub/notfound") + t.Run("not called with custom not found handler", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/sub/notfound") - subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.Write([]byte("Custom 404 handler")) - }) - router.ServeHTTP(rw, req) + subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Write([]byte("Custom 404 handler")) + }) + router.ServeHTTP(rw, req) - if bytes.Contains(rw.Body.Bytes(), mwStr) { - t.Fatal("Middleware was called for a custom 404") - } + if bytes.Contains(rw.Body.Bytes(), mwStr) { + t.Fatal("Middleware was called for a custom 404") + } + }) } func TestMiddlewareMethodMismatchSubrouter(t *testing.T) { @@ -313,27 +341,29 @@ func TestMiddlewareMethodMismatchSubrouter(t *testing.T) { }) }) - // Test method mismatch without custom handler - rw := NewRecorder() - req := newRequest("POST", "/sub/") + t.Run("not called", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("POST", "/sub/") - router.ServeHTTP(rw, req) - if bytes.Contains(rw.Body.Bytes(), mwStr) { - t.Fatal("Middleware was called for a method mismatch") - } + router.ServeHTTP(rw, req) + if bytes.Contains(rw.Body.Bytes(), mwStr) { + t.Fatal("Middleware was called for a method mismatch") + } + }) - // Test method mismatch with custom handler - rw = NewRecorder() - req = newRequest("POST", "/sub/") + t.Run("not called with custom method not allowed handler", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("POST", "/sub/") - router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.Write([]byte("Method not allowed")) - }) - router.ServeHTTP(rw, req) + router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Write([]byte("Method not allowed")) + }) + router.ServeHTTP(rw, req) - if bytes.Contains(rw.Body.Bytes(), mwStr) { - t.Fatal("Middleware was called for a method mismatch") - } + if bytes.Contains(rw.Body.Bytes(), mwStr) { + t.Fatal("Middleware was called for a method mismatch") + } + }) } func TestCORSMethodMiddleware(t *testing.T) { @@ -358,21 +388,23 @@ func TestCORSMethodMiddleware(t *testing.T) { router.Use(CORSMethodMiddleware(router)) - for _, tt := range cases { - rr := httptest.NewRecorder() - req := newRequest(tt.method, tt.testURL) + for i, tt := range cases { + t.Run(fmt.Sprintf("cases[%d]", i), func(t *testing.T) { + rr := httptest.NewRecorder() + req := newRequest(tt.method, tt.testURL) - router.ServeHTTP(rr, req) + router.ServeHTTP(rr, req) - if rr.Body.String() != tt.response { - t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String()) - } + if rr.Body.String() != tt.response { + t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String()) + } - allowedMethods := rr.Header().Get("Access-Control-Allow-Methods") + allowedMethods := rr.Header().Get("Access-Control-Allow-Methods") - if allowedMethods != tt.expectedAllowedMethods { - t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods) - } + if allowedMethods != tt.expectedAllowedMethods { + t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods) + } + }) } } @@ -411,27 +443,33 @@ func TestMiddlewareOnMultiSubrouter(t *testing.T) { }) }) - rw := NewRecorder() - req := newRequest("GET", "/first") + t.Run("/first uses first middleware", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/first") - router.ServeHTTP(rw, req) - if rw.Body.String() != first { - t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", first, rw.Body.String()) - } + router.ServeHTTP(rw, req) + if rw.Body.String() != first { + t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", first, rw.Body.String()) + } + }) - rw = NewRecorder() - req = newRequest("GET", "/second") + t.Run("/second uses second middleware", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/second") - router.ServeHTTP(rw, req) - if rw.Body.String() != second { - t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", second, rw.Body.String()) - } + router.ServeHTTP(rw, req) + if rw.Body.String() != second { + t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", second, rw.Body.String()) + } + }) - rw = NewRecorder() - req = newRequest("GET", "/second/not-exist") + t.Run("uses not found handler", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/second/not-exist") - router.ServeHTTP(rw, req) - if rw.Body.String() != notFound { - t.Fatalf("Notfound handler did not run: expected %s for not-exist, (got %s)", notFound, rw.Body.String()) - } + router.ServeHTTP(rw, req) + if rw.Body.String() != notFound { + t.Fatalf("Notfound handler did not run: expected %s for not-exist, (got %s)", notFound, rw.Body.String()) + } + }) }