diff --git a/src/main.go b/src/main.go index 1c26e0d..0cc653f 100644 --- a/src/main.go +++ b/src/main.go @@ -21,11 +21,11 @@ func main() { log.SetOutput(os.Stdout) log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) - var addr, db, authfile, certfile, keyfile string + var addr, db, authfile, certfile, keyfile, basepath string var ver, open bool flag.StringVar(&addr, "addr", "127.0.0.1:7070", "address to run server on") flag.StringVar(&authfile, "auth-file", "", "path to a file containing username:password") - flag.StringVar(&server.BasePath, "base", "", "base path of the service url") + flag.StringVar(&basepath, "base", "", "base path of the service url") flag.StringVar(&certfile, "cert-file", "", "path to cert file for https") flag.StringVar(&keyfile, "key-file", "", "path to key file for https") flag.StringVar(&db, "db", "", "storage file path") @@ -38,14 +38,6 @@ func main() { return } - if server.BasePath != "" && !strings.HasPrefix(server.BasePath, "/") { - server.BasePath = "/" + server.BasePath - } - - if server.BasePath != "" && strings.HasSuffix(server.BasePath, "/") { - server.BasePath = strings.TrimSuffix(server.BasePath, "/") - } - configPath, err := os.UserConfigDir() if err != nil { log.Fatal("Failed to get config dir: ", err) @@ -92,6 +84,10 @@ func main() { srv := server.NewServer(store, addr) + if basepath != "" { + srv.BasePath = "/" + strings.Trim(basepath, "/") + } + if certfile != "" && keyfile != "" { srv.CertFile = certfile srv.KeyFile = keyfile diff --git a/src/server/routes.go b/src/server/routes.go index 36a3928..d7ea6ef 100644 --- a/src/server/routes.go +++ b/src/server/routes.go @@ -19,11 +19,11 @@ import ( ) func (s *Server) handler() http.Handler { - r := router.NewRouter(BasePath) + r := router.NewRouter(s.BasePath) if s.Username != "" && s.Password != "" { a := &auth.Middleware{ - BasePath: BasePath, + BasePath: s.BasePath, Username: s.Username, Password: s.Password, Public: "/static", @@ -61,7 +61,7 @@ func (s *Server) handleIndex(c *router.Context) { func (s *Server) handleStatic(c *router.Context) { // TODO: gzip? - http.StripPrefix(BasePath+"/static/", http.FileServer(http.FS(assets.FS))).ServeHTTP(c.Out, c.Req) + http.StripPrefix(s.BasePath+"/static/", http.FileServer(http.FS(assets.FS))).ServeHTTP(c.Out, c.Req) } func (s *Server) handleStatus(c *router.Context) { @@ -433,6 +433,6 @@ func (s *Server) handlePageCrawl(c *router.Context) { } func (s *Server) handleLogout(c *router.Context) { - auth.Logout(c.Out, BasePath) + auth.Logout(c.Out, s.BasePath) c.Out.WriteHeader(http.StatusNoContent) } diff --git a/src/server/routes_test.go b/src/server/routes_test.go new file mode 100644 index 0000000..f2958ec --- /dev/null +++ b/src/server/routes_test.go @@ -0,0 +1,33 @@ +package server + +import ( + "net/http/httptest" + "testing" +) + +func TestStatic(t *testing.T) { + handler := NewServer(nil, "127.0.0.1:8000").handler() + url := "/static/javascripts/app.js" + + recorder := httptest.NewRecorder() + request := httptest.NewRequest("GET", url, nil) + handler.ServeHTTP(recorder, request) + if recorder.Result().StatusCode != 200 { + t.FailNow() + } +} + +func TestStaticWithBase(t *testing.T) { + server := NewServer(nil, "127.0.0.1:8000") + server.BasePath = "/sub" + + handler := server.handler() + url := "/sub/static/javascripts/app.js" + + recorder := httptest.NewRecorder() + request := httptest.NewRequest("GET", url, nil) + handler.ServeHTTP(recorder, request) + if recorder.Result().StatusCode != 200 { + t.FailNow() + } +} diff --git a/src/server/server.go b/src/server/server.go index 708aded..11b2176 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -8,12 +8,13 @@ import ( "github.com/nkanaev/yarr/src/worker" ) -var BasePath string = "" - type Server struct { Addr string db *storage.Storage worker *worker.Worker + + BasePath string + // auth Username string Password string @@ -35,7 +36,7 @@ func (h *Server) GetAddr() string { if h.CertFile != "" && h.KeyFile != "" { proto = "https" } - return proto + "://" + h.Addr + BasePath + return proto + "://" + h.Addr + h.BasePath } func (s *Server) Start() {