diff --git a/src/router/context.go b/src/router/context.go index 52ccfcd..83424a8 100644 --- a/src/router/context.go +++ b/src/router/context.go @@ -45,3 +45,10 @@ func (c *Context) QueryInt64(key string) (int64, error) { query := c.Req.URL.Query() return strconv.ParseInt(query.Get("page"), 10, 64) } + +func (c *Context) Redirect(url string) { + if url == "" { + url = "/" + } + http.Redirect(c.Out, c.Req, url, http.StatusFound) +} diff --git a/src/server/middleware.go b/src/server/middleware.go new file mode 100644 index 0000000..dda7342 --- /dev/null +++ b/src/server/middleware.go @@ -0,0 +1,57 @@ +package server + +import ( + "net/http" + "strings" + + "github.com/nkanaev/yarr/src/assets" + "github.com/nkanaev/yarr/src/auth" + "github.com/nkanaev/yarr/src/router" +) + +type authMiddleware struct { + username string + password string + basepath string + public string +} + +func (m *authMiddleware) handler(c *router.Context) { + basepath := m.basepath + if basepath == "" { + basepath = "/" + } + + if strings.HasPrefix(c.Req.URL.Path, m.public) { + c.Next() + return + } + if auth.IsAuthenticated(c.Req, m.username, m.password) { + c.Next() + return + } + + if c.Req.URL.Path != basepath { + // TODO: check ajax + c.Out.WriteHeader(http.StatusForbidden) + return + } + + if c.Req.Method == "POST" { + username := c.Req.FormValue("username") + password := c.Req.FormValue("password") + if auth.StringsEqual(username, m.username) && auth.StringsEqual(password, m.password) { + auth.Authenticate(c.Out, m.username, m.password, m.basepath) + c.Redirect(m.basepath) + return + } else { + // TODO: show error + c.Out.Header().Set("Content-Type", "text/html") + assets.Render("login.html", c.Out, nil) + return + } + } + + c.Out.Header().Set("Content-Type", "text/html") + assets.Render("login.html", c.Out, nil) +} diff --git a/src/server/routes.go b/src/server/routes.go index 163d9b6..c53f668 100644 --- a/src/server/routes.go +++ b/src/server/routes.go @@ -19,6 +19,15 @@ func (s *Server) handler() http.Handler { r := router.NewRouter() // TODO: auth, base, security + if s.Username != "" && s.Password != "" { + a := &authMiddleware{ + username: s.Username, + password: s.Password, + basepath: BasePath, + public: BasePath + "/static", + } + r.Use(a.handler) + } r.For("/", s.handleIndex) r.For("/static/*path", s.handleStatic) @@ -42,21 +51,6 @@ func (s *Server) handler() http.Handler { } func (s *Server) handleIndex(c *router.Context) { - if s.requiresAuth() && !auth.IsAuthenticated(c.Req, s.Username, s.Password) { - if c.Req.Method == "POST" { - username := c.Req.FormValue("username") - password := c.Req.FormValue("password") - if auth.StringsEqual(username, s.Username) && auth.StringsEqual(password, s.Password) { - auth.Authenticate(c.Out, username, password, BasePath) - http.Redirect(c.Out, c.Req, c.Req.URL.Path, http.StatusFound) - return - } - } - - c.Out.Header().Set("Content-Type", "text/html") - assets.Render("login.html", c.Out, nil) - return - } c.Out.Header().Set("Content-Type", "text/html") assets.Render("index.html", c.Out, nil) }