router base

This commit is contained in:
Nazar Kanaev 2021-03-17 15:54:05 +00:00
parent c8bc511e04
commit 9bf7f45354
3 changed files with 59 additions and 26 deletions

View File

@ -3,6 +3,7 @@ package router
import ( import (
"net/http" "net/http"
"regexp" "regexp"
"strings"
) )
type Handler func(*Context) type Handler func(*Context)
@ -10,6 +11,7 @@ type Handler func(*Context)
type Router struct { type Router struct {
middle []Handler middle []Handler
routes []Route routes []Route
base string
} }
type Route struct { type Route struct {
@ -17,10 +19,11 @@ type Route struct {
chain []Handler chain []Handler
} }
func NewRouter() *Router { func NewRouter(base string) *Router {
router := &Router{} router := &Router{}
router.middle = make([]Handler, 0) router.middle = make([]Handler, 0)
router.routes = make([]Route, 0) router.routes = make([]Route, 0)
router.base = base
return router return router
} }
@ -37,16 +40,22 @@ func (r *Router) For(path string, handler Handler) {
} }
func (r *Router) resolve(path string) *Route { func (r *Router) resolve(path string) *Route {
for _, r := range r.routes { for _, route := range r.routes {
if r.regex.MatchString(path) { if route.regex.MatchString(path) {
return &r return &route
} }
} }
return nil return nil
} }
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
path := req.URL.Path // autoclose open base url
if r.base != "" && r.base == req.URL.Path {
http.Redirect(rw, req, r.base + "/", http.StatusFound)
return
}
path := strings.TrimPrefix(req.URL.Path, r.base)
route := r.resolve(path) route := r.resolve(path)
if route == nil { if route == nil {

View File

@ -8,7 +8,7 @@ import (
func TestRouter(t *testing.T) { func TestRouter(t *testing.T) {
middlecalled := false middlecalled := false
router := NewRouter() router := NewRouter("")
router.Use(func(c *Context) { router.Use(func(c *Context) {
middlecalled = true middlecalled = true
c.Next() c.Next()
@ -20,10 +20,7 @@ func TestRouter(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/hello/world", nil) request := httptest.NewRequest("GET", "/hello/world", nil)
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
body, err := io.ReadAll(recorder.Result().Body) body, _ := io.ReadAll(recorder.Result().Body)
if err != nil {
t.Error(err)
}
if !middlecalled { if !middlecalled {
t.Error("middleware not called") t.Error("middleware not called")
@ -37,7 +34,7 @@ func TestRouter(t *testing.T) {
} }
func TestRouterPaths(t *testing.T) { func TestRouterPaths(t *testing.T) {
router := NewRouter() router := NewRouter("")
router.For("/path/to/foo", func(c *Context) { router.For("/path/to/foo", func(c *Context) {
c.Out.Write([]byte("foo")) c.Out.Write([]byte("foo"))
}) })
@ -49,17 +46,14 @@ func TestRouterPaths(t *testing.T) {
request := httptest.NewRequest("GET", "/path/to/bar", nil) request := httptest.NewRequest("GET", "/path/to/bar", nil)
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
body, err := io.ReadAll(recorder.Result().Body) body, _ := io.ReadAll(recorder.Result().Body)
if err != nil {
t.Error(err)
}
if string(body) != "bar" { if string(body) != "bar" {
t.Error("expected 2nd route to be called") t.Error("expected 2nd route to be called")
} }
} }
func TestRouterMiddlewareIntercept(t *testing.T) { func TestRouterMiddlewareIntercept(t *testing.T) {
router := NewRouter() router := NewRouter("")
router.Use(func(c *Context) { router.Use(func(c *Context) {
c.Out.WriteHeader(404) c.Out.WriteHeader(404)
}) })
@ -76,17 +70,14 @@ func TestRouterMiddlewareIntercept(t *testing.T) {
if recorder.Result().StatusCode != 404 { if recorder.Result().StatusCode != 404 {
t.Error("expected 404") t.Error("expected 404")
} }
body, err := io.ReadAll(recorder.Result().Body) body, _ := io.ReadAll(recorder.Result().Body)
if err != nil {
t.Error(err)
}
if len(body) != 0 { if len(body) != 0 {
t.Errorf("expected empty body, got %v", body) t.Errorf("expected empty body, got %v", body)
} }
} }
func TestRouterMiddlewareOrder(t *testing.T) { func TestRouterMiddlewareOrder(t *testing.T) {
router := NewRouter() router := NewRouter("")
router.Use(func(c *Context) { router.Use(func(c *Context) {
c.Out.Write([]byte("foo")) c.Out.Write([]byte("foo"))
@ -113,11 +104,44 @@ func TestRouterMiddlewareOrder(t *testing.T) {
if recorder.Result().StatusCode != 200 { if recorder.Result().StatusCode != 200 {
t.Error("expected 200") t.Error("expected 200")
} }
body, err := io.ReadAll(recorder.Result().Body) body, _ := io.ReadAll(recorder.Result().Body)
if err != nil {
t.Error(err)
}
if string(body) != "foobar!!!" { if string(body) != "foobar!!!" {
t.Errorf("invalid body, got %#v", string(body)) t.Errorf("invalid body, got %#v", string(body))
} }
} }
func TestRouterBase(t *testing.T) {
router := NewRouter("/foo")
router.For("/bar", func(c *Context) {
c.Out.Write([]byte("!!!"))
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/foo/bar", nil)
router.ServeHTTP(recorder, request)
if recorder.Result().StatusCode != 200 {
t.Error("expected 200")
}
body, _ := io.ReadAll(recorder.Result().Body)
if string(body) != "!!!" {
t.Errorf("invalid body, got %#v", string(body))
}
}
func TestRouterBaseRedirect(t *testing.T) {
router := NewRouter("/foo")
router.For("/", func(c *Context) {
c.Out.Write([]byte("!!!"))
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/foo", nil)
router.ServeHTTP(recorder, request)
if recorder.Result().StatusCode != 302 {
t.Errorf("expected 302, got %d", recorder.Result().StatusCode)
}
}

View File

@ -16,7 +16,7 @@ import (
) )
func (s *Server) handler() http.Handler { func (s *Server) handler() http.Handler {
r := router.NewRouter() r := router.NewRouter(BasePath)
// TODO: auth, base, security // TODO: auth, base, security
if s.Username != "" && s.Password != "" { if s.Username != "" && s.Password != "" {