From 9bf7f453543b6e785bd221b74c8c07a38d6bf1d3 Mon Sep 17 00:00:00 2001 From: Nazar Kanaev Date: Wed, 17 Mar 2021 15:54:05 +0000 Subject: [PATCH] router base --- src/router/router.go | 19 +++++++++--- src/router/router_test.go | 64 +++++++++++++++++++++++++++------------ src/server/routes.go | 2 +- 3 files changed, 59 insertions(+), 26 deletions(-) diff --git a/src/router/router.go b/src/router/router.go index e614457..3f0cc3f 100644 --- a/src/router/router.go +++ b/src/router/router.go @@ -3,6 +3,7 @@ package router import ( "net/http" "regexp" + "strings" ) type Handler func(*Context) @@ -10,6 +11,7 @@ type Handler func(*Context) type Router struct { middle []Handler routes []Route + base string } type Route struct { @@ -17,10 +19,11 @@ type Route struct { chain []Handler } -func NewRouter() *Router { +func NewRouter(base string) *Router { router := &Router{} router.middle = make([]Handler, 0) router.routes = make([]Route, 0) + router.base = base return router } @@ -37,16 +40,22 @@ func (r *Router) For(path string, handler Handler) { } func (r *Router) resolve(path string) *Route { - for _, r := range r.routes { - if r.regex.MatchString(path) { - return &r + for _, route := range r.routes { + if route.regex.MatchString(path) { + return &route } } return nil } func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - path := req.URL.Path + // autoclose open base url + if r.base != "" && r.base == req.URL.Path { + http.Redirect(rw, req, r.base + "/", http.StatusFound) + return + } + + path := strings.TrimPrefix(req.URL.Path, r.base) route := r.resolve(path) if route == nil { diff --git a/src/router/router_test.go b/src/router/router_test.go index b395dee..dba5873 100644 --- a/src/router/router_test.go +++ b/src/router/router_test.go @@ -8,7 +8,7 @@ import ( func TestRouter(t *testing.T) { middlecalled := false - router := NewRouter() + router := NewRouter("") router.Use(func(c *Context) { middlecalled = true c.Next() @@ -20,10 +20,7 @@ func TestRouter(t *testing.T) { recorder := httptest.NewRecorder() request := httptest.NewRequest("GET", "/hello/world", nil) router.ServeHTTP(recorder, request) - body, err := io.ReadAll(recorder.Result().Body) - if err != nil { - t.Error(err) - } + body, _ := io.ReadAll(recorder.Result().Body) if !middlecalled { t.Error("middleware not called") @@ -37,7 +34,7 @@ func TestRouter(t *testing.T) { } func TestRouterPaths(t *testing.T) { - router := NewRouter() + router := NewRouter("") router.For("/path/to/foo", func(c *Context) { c.Out.Write([]byte("foo")) }) @@ -49,17 +46,14 @@ func TestRouterPaths(t *testing.T) { request := httptest.NewRequest("GET", "/path/to/bar", nil) router.ServeHTTP(recorder, request) - body, err := io.ReadAll(recorder.Result().Body) - if err != nil { - t.Error(err) - } + body, _ := io.ReadAll(recorder.Result().Body) if string(body) != "bar" { t.Error("expected 2nd route to be called") } } func TestRouterMiddlewareIntercept(t *testing.T) { - router := NewRouter() + router := NewRouter("") router.Use(func(c *Context) { c.Out.WriteHeader(404) }) @@ -76,17 +70,14 @@ func TestRouterMiddlewareIntercept(t *testing.T) { if recorder.Result().StatusCode != 404 { t.Error("expected 404") } - body, err := io.ReadAll(recorder.Result().Body) - if err != nil { - t.Error(err) - } + body, _ := io.ReadAll(recorder.Result().Body) if len(body) != 0 { t.Errorf("expected empty body, got %v", body) } } func TestRouterMiddlewareOrder(t *testing.T) { - router := NewRouter() + router := NewRouter("") router.Use(func(c *Context) { c.Out.Write([]byte("foo")) @@ -113,11 +104,44 @@ func TestRouterMiddlewareOrder(t *testing.T) { if recorder.Result().StatusCode != 200 { t.Error("expected 200") } - body, err := io.ReadAll(recorder.Result().Body) - if err != nil { - t.Error(err) - } + body, _ := io.ReadAll(recorder.Result().Body) if string(body) != "foobar!!!" { t.Errorf("invalid body, got %#v", string(body)) } } + +func TestRouterBase(t *testing.T) { + router := NewRouter("/foo") + router.For("/bar", func(c *Context) { + c.Out.Write([]byte("!!!")) + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest("GET", "/foo/bar", nil) + + router.ServeHTTP(recorder, request) + + if recorder.Result().StatusCode != 200 { + t.Error("expected 200") + } + body, _ := io.ReadAll(recorder.Result().Body) + if string(body) != "!!!" { + t.Errorf("invalid body, got %#v", string(body)) + } +} + +func TestRouterBaseRedirect(t *testing.T) { + router := NewRouter("/foo") + router.For("/", func(c *Context) { + c.Out.Write([]byte("!!!")) + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest("GET", "/foo", nil) + + router.ServeHTTP(recorder, request) + + if recorder.Result().StatusCode != 302 { + t.Errorf("expected 302, got %d", recorder.Result().StatusCode) + } +} diff --git a/src/server/routes.go b/src/server/routes.go index 35a44b4..34b83ef 100644 --- a/src/server/routes.go +++ b/src/server/routes.go @@ -16,7 +16,7 @@ import ( ) func (s *Server) handler() http.Handler { - r := router.NewRouter() + r := router.NewRouter(BasePath) // TODO: auth, base, security if s.Username != "" && s.Password != "" {