rewrite basepath

This commit is contained in:
Nazar Kanaev 2021-04-02 15:41:08 +01:00
parent 1cba53f7fb
commit d7ba203f28
4 changed files with 47 additions and 17 deletions

View File

@ -21,11 +21,11 @@ func main() {
log.SetOutput(os.Stdout) log.SetOutput(os.Stdout)
log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
var addr, db, authfile, certfile, keyfile string var addr, db, authfile, certfile, keyfile, basepath string
var ver, open bool var ver, open bool
flag.StringVar(&addr, "addr", "127.0.0.1:7070", "address to run server on") flag.StringVar(&addr, "addr", "127.0.0.1:7070", "address to run server on")
flag.StringVar(&authfile, "auth-file", "", "path to a file containing username:password") flag.StringVar(&authfile, "auth-file", "", "path to a file containing username:password")
flag.StringVar(&server.BasePath, "base", "", "base path of the service url") flag.StringVar(&basepath, "base", "", "base path of the service url")
flag.StringVar(&certfile, "cert-file", "", "path to cert file for https") flag.StringVar(&certfile, "cert-file", "", "path to cert file for https")
flag.StringVar(&keyfile, "key-file", "", "path to key file for https") flag.StringVar(&keyfile, "key-file", "", "path to key file for https")
flag.StringVar(&db, "db", "", "storage file path") flag.StringVar(&db, "db", "", "storage file path")
@ -38,14 +38,6 @@ func main() {
return return
} }
if server.BasePath != "" && !strings.HasPrefix(server.BasePath, "/") {
server.BasePath = "/" + server.BasePath
}
if server.BasePath != "" && strings.HasSuffix(server.BasePath, "/") {
server.BasePath = strings.TrimSuffix(server.BasePath, "/")
}
configPath, err := os.UserConfigDir() configPath, err := os.UserConfigDir()
if err != nil { if err != nil {
log.Fatal("Failed to get config dir: ", err) log.Fatal("Failed to get config dir: ", err)
@ -92,6 +84,10 @@ func main() {
srv := server.NewServer(store, addr) srv := server.NewServer(store, addr)
if basepath != "" {
srv.BasePath = "/" + strings.Trim(basepath, "/")
}
if certfile != "" && keyfile != "" { if certfile != "" && keyfile != "" {
srv.CertFile = certfile srv.CertFile = certfile
srv.KeyFile = keyfile srv.KeyFile = keyfile

View File

@ -19,11 +19,11 @@ import (
) )
func (s *Server) handler() http.Handler { func (s *Server) handler() http.Handler {
r := router.NewRouter(BasePath) r := router.NewRouter(s.BasePath)
if s.Username != "" && s.Password != "" { if s.Username != "" && s.Password != "" {
a := &auth.Middleware{ a := &auth.Middleware{
BasePath: BasePath, BasePath: s.BasePath,
Username: s.Username, Username: s.Username,
Password: s.Password, Password: s.Password,
Public: "/static", Public: "/static",
@ -61,7 +61,7 @@ func (s *Server) handleIndex(c *router.Context) {
func (s *Server) handleStatic(c *router.Context) { func (s *Server) handleStatic(c *router.Context) {
// TODO: gzip? // TODO: gzip?
http.StripPrefix(BasePath+"/static/", http.FileServer(http.FS(assets.FS))).ServeHTTP(c.Out, c.Req) http.StripPrefix(s.BasePath+"/static/", http.FileServer(http.FS(assets.FS))).ServeHTTP(c.Out, c.Req)
} }
func (s *Server) handleStatus(c *router.Context) { func (s *Server) handleStatus(c *router.Context) {
@ -433,6 +433,6 @@ func (s *Server) handlePageCrawl(c *router.Context) {
} }
func (s *Server) handleLogout(c *router.Context) { func (s *Server) handleLogout(c *router.Context) {
auth.Logout(c.Out, BasePath) auth.Logout(c.Out, s.BasePath)
c.Out.WriteHeader(http.StatusNoContent) c.Out.WriteHeader(http.StatusNoContent)
} }

33
src/server/routes_test.go Normal file
View File

@ -0,0 +1,33 @@
package server
import (
"net/http/httptest"
"testing"
)
func TestStatic(t *testing.T) {
handler := NewServer(nil, "127.0.0.1:8000").handler()
url := "/static/javascripts/app.js"
recorder := httptest.NewRecorder()
request := httptest.NewRequest("GET", url, nil)
handler.ServeHTTP(recorder, request)
if recorder.Result().StatusCode != 200 {
t.FailNow()
}
}
func TestStaticWithBase(t *testing.T) {
server := NewServer(nil, "127.0.0.1:8000")
server.BasePath = "/sub"
handler := server.handler()
url := "/sub/static/javascripts/app.js"
recorder := httptest.NewRecorder()
request := httptest.NewRequest("GET", url, nil)
handler.ServeHTTP(recorder, request)
if recorder.Result().StatusCode != 200 {
t.FailNow()
}
}

View File

@ -8,12 +8,13 @@ import (
"github.com/nkanaev/yarr/src/worker" "github.com/nkanaev/yarr/src/worker"
) )
var BasePath string = ""
type Server struct { type Server struct {
Addr string Addr string
db *storage.Storage db *storage.Storage
worker *worker.Worker worker *worker.Worker
BasePath string
// auth // auth
Username string Username string
Password string Password string
@ -35,7 +36,7 @@ func (h *Server) GetAddr() string {
if h.CertFile != "" && h.KeyFile != "" { if h.CertFile != "" && h.KeyFile != "" {
proto = "https" proto = "https"
} }
return proto + "://" + h.Addr + BasePath return proto + "://" + h.Addr + h.BasePath
} }
func (s *Server) Start() { func (s *Server) Start() {