mirror of
https://github.com/nkanaev/yarr.git
synced 2025-05-24 21:19:19 +00:00
router base
This commit is contained in:
parent
c8bc511e04
commit
9bf7f45354
@ -3,6 +3,7 @@ package router
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Handler func(*Context)
|
type Handler func(*Context)
|
||||||
@ -10,6 +11,7 @@ type Handler func(*Context)
|
|||||||
type Router struct {
|
type Router struct {
|
||||||
middle []Handler
|
middle []Handler
|
||||||
routes []Route
|
routes []Route
|
||||||
|
base string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Route struct {
|
type Route struct {
|
||||||
@ -17,10 +19,11 @@ type Route struct {
|
|||||||
chain []Handler
|
chain []Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRouter() *Router {
|
func NewRouter(base string) *Router {
|
||||||
router := &Router{}
|
router := &Router{}
|
||||||
router.middle = make([]Handler, 0)
|
router.middle = make([]Handler, 0)
|
||||||
router.routes = make([]Route, 0)
|
router.routes = make([]Route, 0)
|
||||||
|
router.base = base
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,16 +40,22 @@ func (r *Router) For(path string, handler Handler) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) resolve(path string) *Route {
|
func (r *Router) resolve(path string) *Route {
|
||||||
for _, r := range r.routes {
|
for _, route := range r.routes {
|
||||||
if r.regex.MatchString(path) {
|
if route.regex.MatchString(path) {
|
||||||
return &r
|
return &route
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
path := req.URL.Path
|
// autoclose open base url
|
||||||
|
if r.base != "" && r.base == req.URL.Path {
|
||||||
|
http.Redirect(rw, req, r.base + "/", http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
path := strings.TrimPrefix(req.URL.Path, r.base)
|
||||||
|
|
||||||
route := r.resolve(path)
|
route := r.resolve(path)
|
||||||
if route == nil {
|
if route == nil {
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
|
|
||||||
func TestRouter(t *testing.T) {
|
func TestRouter(t *testing.T) {
|
||||||
middlecalled := false
|
middlecalled := false
|
||||||
router := NewRouter()
|
router := NewRouter("")
|
||||||
router.Use(func(c *Context) {
|
router.Use(func(c *Context) {
|
||||||
middlecalled = true
|
middlecalled = true
|
||||||
c.Next()
|
c.Next()
|
||||||
@ -20,10 +20,7 @@ func TestRouter(t *testing.T) {
|
|||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
request := httptest.NewRequest("GET", "/hello/world", nil)
|
request := httptest.NewRequest("GET", "/hello/world", nil)
|
||||||
router.ServeHTTP(recorder, request)
|
router.ServeHTTP(recorder, request)
|
||||||
body, err := io.ReadAll(recorder.Result().Body)
|
body, _ := io.ReadAll(recorder.Result().Body)
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !middlecalled {
|
if !middlecalled {
|
||||||
t.Error("middleware not called")
|
t.Error("middleware not called")
|
||||||
@ -37,7 +34,7 @@ func TestRouter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterPaths(t *testing.T) {
|
func TestRouterPaths(t *testing.T) {
|
||||||
router := NewRouter()
|
router := NewRouter("")
|
||||||
router.For("/path/to/foo", func(c *Context) {
|
router.For("/path/to/foo", func(c *Context) {
|
||||||
c.Out.Write([]byte("foo"))
|
c.Out.Write([]byte("foo"))
|
||||||
})
|
})
|
||||||
@ -49,17 +46,14 @@ func TestRouterPaths(t *testing.T) {
|
|||||||
request := httptest.NewRequest("GET", "/path/to/bar", nil)
|
request := httptest.NewRequest("GET", "/path/to/bar", nil)
|
||||||
router.ServeHTTP(recorder, request)
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
body, err := io.ReadAll(recorder.Result().Body)
|
body, _ := io.ReadAll(recorder.Result().Body)
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
if string(body) != "bar" {
|
if string(body) != "bar" {
|
||||||
t.Error("expected 2nd route to be called")
|
t.Error("expected 2nd route to be called")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterMiddlewareIntercept(t *testing.T) {
|
func TestRouterMiddlewareIntercept(t *testing.T) {
|
||||||
router := NewRouter()
|
router := NewRouter("")
|
||||||
router.Use(func(c *Context) {
|
router.Use(func(c *Context) {
|
||||||
c.Out.WriteHeader(404)
|
c.Out.WriteHeader(404)
|
||||||
})
|
})
|
||||||
@ -76,17 +70,14 @@ func TestRouterMiddlewareIntercept(t *testing.T) {
|
|||||||
if recorder.Result().StatusCode != 404 {
|
if recorder.Result().StatusCode != 404 {
|
||||||
t.Error("expected 404")
|
t.Error("expected 404")
|
||||||
}
|
}
|
||||||
body, err := io.ReadAll(recorder.Result().Body)
|
body, _ := io.ReadAll(recorder.Result().Body)
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
if len(body) != 0 {
|
if len(body) != 0 {
|
||||||
t.Errorf("expected empty body, got %v", body)
|
t.Errorf("expected empty body, got %v", body)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterMiddlewareOrder(t *testing.T) {
|
func TestRouterMiddlewareOrder(t *testing.T) {
|
||||||
router := NewRouter()
|
router := NewRouter("")
|
||||||
|
|
||||||
router.Use(func(c *Context) {
|
router.Use(func(c *Context) {
|
||||||
c.Out.Write([]byte("foo"))
|
c.Out.Write([]byte("foo"))
|
||||||
@ -113,11 +104,44 @@ func TestRouterMiddlewareOrder(t *testing.T) {
|
|||||||
if recorder.Result().StatusCode != 200 {
|
if recorder.Result().StatusCode != 200 {
|
||||||
t.Error("expected 200")
|
t.Error("expected 200")
|
||||||
}
|
}
|
||||||
body, err := io.ReadAll(recorder.Result().Body)
|
body, _ := io.ReadAll(recorder.Result().Body)
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
if string(body) != "foobar!!!" {
|
if string(body) != "foobar!!!" {
|
||||||
t.Errorf("invalid body, got %#v", string(body))
|
t.Errorf("invalid body, got %#v", string(body))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouterBase(t *testing.T) {
|
||||||
|
router := NewRouter("/foo")
|
||||||
|
router.For("/bar", func(c *Context) {
|
||||||
|
c.Out.Write([]byte("!!!"))
|
||||||
|
})
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
request := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if recorder.Result().StatusCode != 200 {
|
||||||
|
t.Error("expected 200")
|
||||||
|
}
|
||||||
|
body, _ := io.ReadAll(recorder.Result().Body)
|
||||||
|
if string(body) != "!!!" {
|
||||||
|
t.Errorf("invalid body, got %#v", string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterBaseRedirect(t *testing.T) {
|
||||||
|
router := NewRouter("/foo")
|
||||||
|
router.For("/", func(c *Context) {
|
||||||
|
c.Out.Write([]byte("!!!"))
|
||||||
|
})
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
request := httptest.NewRequest("GET", "/foo", nil)
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if recorder.Result().StatusCode != 302 {
|
||||||
|
t.Errorf("expected 302, got %d", recorder.Result().StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -16,7 +16,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handler() http.Handler {
|
func (s *Server) handler() http.Handler {
|
||||||
r := router.NewRouter()
|
r := router.NewRouter(BasePath)
|
||||||
|
|
||||||
// TODO: auth, base, security
|
// TODO: auth, base, security
|
||||||
if s.Username != "" && s.Password != "" {
|
if s.Username != "" && s.Password != "" {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user