From 66fdbef90b037cdd177e74c9945af26f326f3599 Mon Sep 17 00:00:00 2001 From: Nazar Kanaev Date: Tue, 16 Mar 2021 21:16:27 +0000 Subject: [PATCH] test router --- src/router/router_test.go | 123 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 src/router/router_test.go diff --git a/src/router/router_test.go b/src/router/router_test.go new file mode 100644 index 0000000..b395dee --- /dev/null +++ b/src/router/router_test.go @@ -0,0 +1,123 @@ +package router + +import ( + "io" + "net/http/httptest" + "testing" +) + +func TestRouter(t *testing.T) { + middlecalled := false + router := NewRouter() + router.Use(func(c *Context) { + middlecalled = true + c.Next() + }) + router.For("/hello/:place", func(c *Context) { + c.Out.Write([]byte(c.Vars["place"])) + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest("GET", "/hello/world", nil) + router.ServeHTTP(recorder, request) + body, err := io.ReadAll(recorder.Result().Body) + if err != nil { + t.Error(err) + } + + if !middlecalled { + t.Error("middleware not called") + } + if recorder.Result().StatusCode != 200 { + t.Error("expected 200") + } + if string(body) != "world" { + t.Errorf("invalid response body, got %#v", body) + } +} + +func TestRouterPaths(t *testing.T) { + router := NewRouter() + router.For("/path/to/foo", func(c *Context) { + c.Out.Write([]byte("foo")) + }) + router.For("/path/to/bar", func(c *Context) { + c.Out.Write([]byte("bar")) + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest("GET", "/path/to/bar", nil) + router.ServeHTTP(recorder, request) + + body, err := io.ReadAll(recorder.Result().Body) + if err != nil { + t.Error(err) + } + if string(body) != "bar" { + t.Error("expected 2nd route to be called") + } +} + +func TestRouterMiddlewareIntercept(t *testing.T) { + router := NewRouter() + router.Use(func(c *Context) { + c.Out.WriteHeader(404) + }) + router.For("/hello/:place", func(c *Context) { + c.Out.WriteHeader(200) + c.Out.Write([]byte(c.Vars["place"])) + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest("GET", "/hello/world", nil) + + router.ServeHTTP(recorder, request) + + if recorder.Result().StatusCode != 404 { + t.Error("expected 404") + } + body, err := io.ReadAll(recorder.Result().Body) + if err != nil { + t.Error(err) + } + if len(body) != 0 { + t.Errorf("expected empty body, got %v", body) + } +} + +func TestRouterMiddlewareOrder(t *testing.T) { + router := NewRouter() + + router.Use(func(c *Context) { + c.Out.Write([]byte("foo")) + c.Next() + }) + router.Use(func(c *Context) { + c.Out.Write([]byte("bar")) + c.Next() + }) + router.For("/hello/:place", func(c *Context) { + c.Out.Write([]byte("!!!")) + }) + + router.Use(func(c *Context) { + c.Out.Write([]byte("baz")) + c.Next() + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest("GET", "/hello/world", nil) + + router.ServeHTTP(recorder, request) + + if recorder.Result().StatusCode != 200 { + t.Error("expected 200") + } + body, err := io.ReadAll(recorder.Result().Body) + if err != nil { + t.Error(err) + } + if string(body) != "foobar!!!" { + t.Errorf("invalid body, got %#v", string(body)) + } +}