mirror of
https://github.com/nkanaev/yarr.git
synced 2025-11-09 19:08:57 +00:00
reorganizing server-related packages
This commit is contained in:
52
src/server/auth/auth.go
Normal file
52
src/server/auth/auth.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func IsAuthenticated(req *http.Request, username, password string) bool {
|
||||
cookie, _ := req.Cookie("auth")
|
||||
if cookie == nil {
|
||||
return false
|
||||
}
|
||||
parts := strings.Split(cookie.Value, ":")
|
||||
if len(parts) != 2 || !StringsEqual(parts[0], username) {
|
||||
return false
|
||||
}
|
||||
return StringsEqual(parts[1], secret(username, password))
|
||||
}
|
||||
|
||||
func Authenticate(rw http.ResponseWriter, username, password, basepath string) {
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: "auth",
|
||||
Value: username + ":" + secret(username, password),
|
||||
Expires: time.Now().Add(time.Hour * 24 * 7), // 1 week,
|
||||
Path: basepath,
|
||||
})
|
||||
}
|
||||
|
||||
func Logout(rw http.ResponseWriter, basepath string) {
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: "auth",
|
||||
Value: "",
|
||||
MaxAge: -1,
|
||||
Path: basepath,
|
||||
})
|
||||
}
|
||||
|
||||
func StringsEqual(p1, p2 string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(p1), []byte(p2)) == 1
|
||||
}
|
||||
|
||||
func secret(msg, key string) string {
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
mac.Write([]byte(msg))
|
||||
src := mac.Sum(nil)
|
||||
return hex.EncodeToString(src)
|
||||
}
|
||||
55
src/server/auth/middleware.go
Normal file
55
src/server/auth/middleware.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/nkanaev/yarr/src/assets"
|
||||
"github.com/nkanaev/yarr/src/server/router"
|
||||
)
|
||||
|
||||
type Middleware struct {
|
||||
Username string
|
||||
Password string
|
||||
BasePath string
|
||||
Public string
|
||||
}
|
||||
|
||||
func unsafeMethod(method string) bool {
|
||||
return method == "POST" || method == "PUT" || method == "DELETE"
|
||||
}
|
||||
|
||||
func (m *Middleware) Handler(c *router.Context) {
|
||||
if strings.HasPrefix(c.Req.URL.Path, m.BasePath+m.Public) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
if IsAuthenticated(c.Req, m.Username, m.Password) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
rootUrl := m.BasePath + "/"
|
||||
|
||||
if c.Req.URL.Path != rootUrl {
|
||||
c.Out.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if c.Req.Method == "POST" {
|
||||
username := c.Req.FormValue("username")
|
||||
password := c.Req.FormValue("password")
|
||||
if StringsEqual(username, m.Username) && StringsEqual(password, m.Password) {
|
||||
Authenticate(c.Out, m.Username, m.Password, m.BasePath)
|
||||
c.Redirect(rootUrl)
|
||||
return
|
||||
} else {
|
||||
c.HTML(http.StatusOK, assets.Template("login.html"), map[string]string{
|
||||
"username": username,
|
||||
"error": "Invalid username/password",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
c.HTML(http.StatusOK, assets.Template("login.html"), nil)
|
||||
}
|
||||
70
src/server/opml/opml.go
Normal file
70
src/server/opml/opml.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package opml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Folder struct {
|
||||
Title string
|
||||
Folders []Folder
|
||||
Feeds []Feed
|
||||
}
|
||||
|
||||
type Feed struct {
|
||||
Title string
|
||||
FeedUrl string
|
||||
SiteUrl string
|
||||
}
|
||||
|
||||
func (f Folder) AllFeeds() []Feed {
|
||||
feeds := make([]Feed, 0)
|
||||
feeds = append(feeds, f.Feeds...)
|
||||
for _, subfolder := range f.Folders {
|
||||
feeds = append(feeds, subfolder.AllFeeds()...)
|
||||
}
|
||||
return feeds
|
||||
}
|
||||
|
||||
var e = html.EscapeString
|
||||
var indent = " "
|
||||
var nl = "\n"
|
||||
|
||||
func (f Folder) outline(level int) string {
|
||||
builder := strings.Builder{}
|
||||
prefix := strings.Repeat(indent, level)
|
||||
|
||||
if level > 0 {
|
||||
builder.WriteString(prefix + fmt.Sprintf(`<outline text="%s">`+nl, e(f.Title)))
|
||||
}
|
||||
for _, folder := range f.Folders {
|
||||
builder.WriteString(folder.outline(level + 1))
|
||||
}
|
||||
for _, feed := range f.Feeds {
|
||||
builder.WriteString(feed.outline(level + 1))
|
||||
}
|
||||
if level > 0 {
|
||||
builder.WriteString(prefix + `</outline>` + nl)
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (f Feed) outline(level int) string {
|
||||
return strings.Repeat(indent, level) + fmt.Sprintf(
|
||||
`<outline type="rss" text="%s" xmlUrl="%s" htmlUrl="%s"/>`+nl,
|
||||
e(f.Title), e(f.FeedUrl), e(f.SiteUrl),
|
||||
)
|
||||
}
|
||||
|
||||
func (f Folder) OPML() string {
|
||||
builder := strings.Builder{}
|
||||
builder.WriteString(`<?xml version="1.0" encoding="UTF-8"?>` + nl)
|
||||
builder.WriteString(`<opml version="1.1">` + nl)
|
||||
builder.WriteString(`<head><title>subscriptions</title></head>` + nl)
|
||||
builder.WriteString(`<body>` + nl)
|
||||
builder.WriteString(f.outline(0))
|
||||
builder.WriteString(`</body>` + nl)
|
||||
builder.WriteString(`</opml>` + nl)
|
||||
return builder.String()
|
||||
}
|
||||
54
src/server/opml/opml_test.go
Normal file
54
src/server/opml/opml_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package opml
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOPML(t *testing.T) {
|
||||
have := (Folder{
|
||||
Title: "",
|
||||
Feeds: []Feed{
|
||||
{
|
||||
Title: "title1",
|
||||
FeedUrl: "https://baz.com/feed.xml",
|
||||
SiteUrl: "https://baz.com/",
|
||||
},
|
||||
},
|
||||
Folders: []Folder{
|
||||
{
|
||||
Title: "sub",
|
||||
Feeds: []Feed{
|
||||
{
|
||||
Title: "subtitle1",
|
||||
FeedUrl: "https://foo.com/feed.xml",
|
||||
SiteUrl: "https://foo.com/",
|
||||
},
|
||||
{
|
||||
Title: "&>",
|
||||
FeedUrl: "https://bar.com/feed.xml",
|
||||
SiteUrl: "https://bar.com/",
|
||||
},
|
||||
},
|
||||
Folders: []Folder{},
|
||||
},
|
||||
},
|
||||
}).OPML()
|
||||
want := `<?xml version="1.0" encoding="UTF-8"?>
|
||||
<opml version="1.1">
|
||||
<head><title>subscriptions</title></head>
|
||||
<body>
|
||||
<outline text="sub">
|
||||
<outline type="rss" text="subtitle1" xmlUrl="https://foo.com/feed.xml" htmlUrl="https://foo.com/"/>
|
||||
<outline type="rss" text="&>" xmlUrl="https://bar.com/feed.xml" htmlUrl="https://bar.com/"/>
|
||||
</outline>
|
||||
<outline type="rss" text="title1" xmlUrl="https://baz.com/feed.xml" htmlUrl="https://baz.com/"/>
|
||||
</body>
|
||||
</opml>
|
||||
`
|
||||
if !reflect.DeepEqual(want, have) {
|
||||
t.Logf("want: %s", want)
|
||||
t.Logf("have: %s", have)
|
||||
t.Fatal("invalid opml")
|
||||
}
|
||||
}
|
||||
49
src/server/opml/read.go
Normal file
49
src/server/opml/read.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package opml
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"io"
|
||||
)
|
||||
|
||||
type opml struct {
|
||||
XMLName xml.Name `xml:"opml"`
|
||||
Outlines []outline `xml:"body>outline"`
|
||||
}
|
||||
|
||||
type outline struct {
|
||||
Type string `xml:"type,attr,omitempty"`
|
||||
Title string `xml:"text,attr"`
|
||||
FeedUrl string `xml:"xmlUrl,attr,omitempty"`
|
||||
SiteUrl string `xml:"htmlUrl,attr,omitempty"`
|
||||
Outlines []outline `xml:"outline,omitempty"`
|
||||
}
|
||||
|
||||
func buildFolder(title string, outlines []outline) Folder {
|
||||
folder := Folder{Title: title}
|
||||
for _, outline := range outlines {
|
||||
if outline.Type == "rss" {
|
||||
folder.Feeds = append(folder.Feeds, Feed{
|
||||
Title: outline.Title,
|
||||
FeedUrl: outline.FeedUrl,
|
||||
SiteUrl: outline.SiteUrl,
|
||||
})
|
||||
} else {
|
||||
subfolder := buildFolder(outline.Title, outline.Outlines)
|
||||
folder.Folders = append(folder.Folders, subfolder)
|
||||
}
|
||||
}
|
||||
return folder
|
||||
}
|
||||
|
||||
func Parse(r io.Reader) (Folder, error) {
|
||||
val := new(opml)
|
||||
decoder := xml.NewDecoder(r)
|
||||
decoder.Entity = xml.HTMLEntity
|
||||
decoder.Strict = false
|
||||
|
||||
err := decoder.Decode(&val)
|
||||
if err != nil {
|
||||
return Folder{}, err
|
||||
}
|
||||
return buildFolder("", val.Outlines), nil
|
||||
}
|
||||
58
src/server/opml/read_test.go
Normal file
58
src/server/opml/read_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package opml
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
have, _ := Parse(strings.NewReader(`
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<opml version="1.1">
|
||||
<head><title>Subscriptions</title></head>
|
||||
<body>
|
||||
<outline text="sub">
|
||||
<outline type="rss" text="subtitle1" description="sub1"
|
||||
xmlUrl="https://foo.com/feed.xml" htmlUrl="https://foo.com/"/>
|
||||
<outline type="rss" text="&>" description="<>"
|
||||
xmlUrl="https://bar.com/feed.xml" htmlUrl="https://bar.com/"/>
|
||||
</outline>
|
||||
<outline type="rss" text="title1" description="desc1"
|
||||
xmlUrl="https://baz.com/feed.xml" htmlUrl="https://baz.com/"/>
|
||||
</body>
|
||||
</opml>
|
||||
`))
|
||||
want := Folder{
|
||||
Title: "",
|
||||
Feeds: []Feed{
|
||||
{
|
||||
Title: "title1",
|
||||
FeedUrl: "https://baz.com/feed.xml",
|
||||
SiteUrl: "https://baz.com/",
|
||||
},
|
||||
},
|
||||
Folders: []Folder{
|
||||
{
|
||||
Title: "sub",
|
||||
Feeds: []Feed{
|
||||
{
|
||||
Title: "subtitle1",
|
||||
FeedUrl: "https://foo.com/feed.xml",
|
||||
SiteUrl: "https://foo.com/",
|
||||
},
|
||||
{
|
||||
Title: "&>",
|
||||
FeedUrl: "https://bar.com/feed.xml",
|
||||
SiteUrl: "https://bar.com/",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(want, have) {
|
||||
t.Logf("want: %#v", want)
|
||||
t.Logf("have: %#v", have)
|
||||
t.Fatal("invalid opml")
|
||||
}
|
||||
}
|
||||
61
src/server/router/context.go
Normal file
61
src/server/router/context.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
Req *http.Request
|
||||
Out http.ResponseWriter
|
||||
|
||||
Vars map[string]string
|
||||
|
||||
chain []Handler
|
||||
index int
|
||||
}
|
||||
|
||||
func (c *Context) Next() {
|
||||
c.index++
|
||||
c.chain[c.index](c)
|
||||
}
|
||||
|
||||
func (c *Context) JSON(status int, data interface{}) {
|
||||
body, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
c.Out.WriteHeader(status)
|
||||
c.Out.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
c.Out.Write(body)
|
||||
c.Out.Write([]byte("\n"))
|
||||
}
|
||||
|
||||
func (c *Context) HTML(status int, tmpl *template.Template, data interface{}) {
|
||||
c.Out.WriteHeader(status)
|
||||
c.Out.Header().Set("Content-Type", "text/html")
|
||||
tmpl.Execute(c.Out, data)
|
||||
}
|
||||
|
||||
func (c *Context) VarInt64(key string) (int64, error) {
|
||||
if val, ok := c.Vars[key]; ok {
|
||||
return strconv.ParseInt(val, 10, 64)
|
||||
}
|
||||
return 0, fmt.Errorf("no such var: %s", key)
|
||||
}
|
||||
|
||||
func (c *Context) QueryInt64(key string) (int64, error) {
|
||||
query := c.Req.URL.Query()
|
||||
return strconv.ParseInt(query.Get(key), 10, 64)
|
||||
}
|
||||
|
||||
func (c *Context) Redirect(url string) {
|
||||
if url == "" {
|
||||
url = "/"
|
||||
}
|
||||
http.Redirect(c.Out, c.Req, url, http.StatusFound)
|
||||
}
|
||||
24
src/server/router/match.go
Normal file
24
src/server/router/match.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package router
|
||||
|
||||
import "regexp"
|
||||
|
||||
func regexGroups(input string, regex *regexp.Regexp) map[string]string {
|
||||
groups := make(map[string]string)
|
||||
matches := regex.FindStringSubmatchIndex(input)
|
||||
for i, key := range regex.SubexpNames()[1:] {
|
||||
groups[key] = input[matches[i*2+2]:matches[i*2+3]]
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
||||
func routeRegexp(route string) *regexp.Regexp {
|
||||
chunks := regexp.MustCompile(`[\*\:]\w+`)
|
||||
output := chunks.ReplaceAllStringFunc(route, func(m string) string {
|
||||
if m[0:1] == `*` {
|
||||
return "(?P<" + m[1:] + ">.+)"
|
||||
}
|
||||
return "(?P<" + m[1:] + ">[^/]+)"
|
||||
})
|
||||
output = "^" + output + "$"
|
||||
return regexp.MustCompile(output)
|
||||
}
|
||||
76
src/server/router/match_test.go
Normal file
76
src/server/router/match_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRouteRegexpPart(t *testing.T) {
|
||||
in := "/hello/:world"
|
||||
re := routeRegexp(in)
|
||||
|
||||
pos := []string{
|
||||
"/hello/world",
|
||||
"/hello/1234",
|
||||
"/hello/bbc1",
|
||||
}
|
||||
for _, c := range pos {
|
||||
if !re.MatchString(c) {
|
||||
t.Errorf("%v must match %v", in, c)
|
||||
}
|
||||
}
|
||||
|
||||
neg := []string{
|
||||
"/hello",
|
||||
"/hello/world/",
|
||||
"/sub/hello/123",
|
||||
"//hello/123",
|
||||
"/hello/123/hello/",
|
||||
}
|
||||
for _, c := range neg {
|
||||
if re.MatchString(c) {
|
||||
t.Errorf("%q must not match %q", in, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteRegexpStar(t *testing.T) {
|
||||
in := "/hello/*world"
|
||||
re := routeRegexp(in)
|
||||
|
||||
pos := []string{"/hello/world", "/hello/world/test"}
|
||||
for _, c := range pos {
|
||||
if !re.MatchString(c) {
|
||||
t.Errorf("%q must match %q", in, c)
|
||||
}
|
||||
}
|
||||
|
||||
neg := []string{"/hello/", "/hello"}
|
||||
for _, c := range neg {
|
||||
if re.MatchString(c) {
|
||||
t.Errorf("%v must not match %v", in, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexGroupsPart(t *testing.T) {
|
||||
re := routeRegexp("/foo/:bar/1/:baz")
|
||||
|
||||
expect := map[string]string{"bar": "one", "baz": "two"}
|
||||
actual := regexGroups("/foo/one/1/two", re)
|
||||
|
||||
if !reflect.DeepEqual(expect, actual) {
|
||||
t.Errorf("expected: %q, actual: %q", expect, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexGroupsStar(t *testing.T) {
|
||||
re := routeRegexp("/foo/*bar")
|
||||
|
||||
expect := map[string]string{"bar": "bar/baz/"}
|
||||
actual := regexGroups("/foo/bar/baz/", re)
|
||||
|
||||
if !reflect.DeepEqual(expect, actual) {
|
||||
t.Errorf("expected: %q, actual: %q", expect, actual)
|
||||
}
|
||||
}
|
||||
73
src/server/router/router.go
Normal file
73
src/server/router/router.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Handler func(*Context)
|
||||
|
||||
type Router struct {
|
||||
middle []Handler
|
||||
routes []Route
|
||||
base string
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
regex *regexp.Regexp
|
||||
chain []Handler
|
||||
}
|
||||
|
||||
func NewRouter(base string) *Router {
|
||||
router := &Router{}
|
||||
router.middle = make([]Handler, 0)
|
||||
router.routes = make([]Route, 0)
|
||||
router.base = base
|
||||
return router
|
||||
}
|
||||
|
||||
func (r *Router) Use(h Handler) {
|
||||
r.middle = append(r.middle, h)
|
||||
}
|
||||
|
||||
func (r *Router) For(path string, handler Handler) {
|
||||
x := Route{}
|
||||
x.regex = routeRegexp(path)
|
||||
x.chain = append(r.middle, handler)
|
||||
|
||||
r.routes = append(r.routes, x)
|
||||
}
|
||||
|
||||
func (r *Router) resolve(path string) *Route {
|
||||
for _, route := range r.routes {
|
||||
if route.regex.MatchString(path) {
|
||||
return &route
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// autoclose open base url
|
||||
if r.base != "" && r.base == req.URL.Path {
|
||||
http.Redirect(rw, req, r.base+"/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
path := strings.TrimPrefix(req.URL.Path, r.base)
|
||||
|
||||
route := r.resolve(path)
|
||||
if route == nil {
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
context := &Context{}
|
||||
context.Req = req
|
||||
context.Out = rw
|
||||
context.Vars = regexGroups(path, route.regex)
|
||||
context.index = -1
|
||||
context.chain = route.chain
|
||||
context.Next()
|
||||
}
|
||||
147
src/server/router/router_test.go
Normal file
147
src/server/router/router_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRouter(t *testing.T) {
|
||||
middlecalled := false
|
||||
router := NewRouter("")
|
||||
router.Use(func(c *Context) {
|
||||
middlecalled = true
|
||||
c.Next()
|
||||
})
|
||||
router.For("/hello/:place", func(c *Context) {
|
||||
c.Out.Write([]byte(c.Vars["place"]))
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest("GET", "/hello/world", nil)
|
||||
router.ServeHTTP(recorder, request)
|
||||
body, _ := io.ReadAll(recorder.Result().Body)
|
||||
|
||||
if !middlecalled {
|
||||
t.Error("middleware not called")
|
||||
}
|
||||
if recorder.Result().StatusCode != 200 {
|
||||
t.Error("expected 200")
|
||||
}
|
||||
if string(body) != "world" {
|
||||
t.Errorf("invalid response body, got %#v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterPaths(t *testing.T) {
|
||||
router := NewRouter("")
|
||||
router.For("/path/to/foo", func(c *Context) {
|
||||
c.Out.Write([]byte("foo"))
|
||||
})
|
||||
router.For("/path/to/bar", func(c *Context) {
|
||||
c.Out.Write([]byte("bar"))
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest("GET", "/path/to/bar", nil)
|
||||
router.ServeHTTP(recorder, request)
|
||||
|
||||
body, _ := io.ReadAll(recorder.Result().Body)
|
||||
if string(body) != "bar" {
|
||||
t.Error("expected 2nd route to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterMiddlewareIntercept(t *testing.T) {
|
||||
router := NewRouter("")
|
||||
router.Use(func(c *Context) {
|
||||
c.Out.WriteHeader(404)
|
||||
})
|
||||
router.For("/hello/:place", func(c *Context) {
|
||||
c.Out.WriteHeader(200)
|
||||
c.Out.Write([]byte(c.Vars["place"]))
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest("GET", "/hello/world", nil)
|
||||
|
||||
router.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Result().StatusCode != 404 {
|
||||
t.Error("expected 404")
|
||||
}
|
||||
body, _ := io.ReadAll(recorder.Result().Body)
|
||||
if len(body) != 0 {
|
||||
t.Errorf("expected empty body, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterMiddlewareOrder(t *testing.T) {
|
||||
router := NewRouter("")
|
||||
|
||||
router.Use(func(c *Context) {
|
||||
c.Out.Write([]byte("foo"))
|
||||
c.Next()
|
||||
})
|
||||
router.Use(func(c *Context) {
|
||||
c.Out.Write([]byte("bar"))
|
||||
c.Next()
|
||||
})
|
||||
router.For("/hello/:place", func(c *Context) {
|
||||
c.Out.Write([]byte("!!!"))
|
||||
})
|
||||
|
||||
router.Use(func(c *Context) {
|
||||
c.Out.Write([]byte("baz"))
|
||||
c.Next()
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest("GET", "/hello/world", nil)
|
||||
|
||||
router.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Result().StatusCode != 200 {
|
||||
t.Error("expected 200")
|
||||
}
|
||||
body, _ := io.ReadAll(recorder.Result().Body)
|
||||
if string(body) != "foobar!!!" {
|
||||
t.Errorf("invalid body, got %#v", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterBase(t *testing.T) {
|
||||
router := NewRouter("/foo")
|
||||
router.For("/bar", func(c *Context) {
|
||||
c.Out.Write([]byte("!!!"))
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
|
||||
router.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Result().StatusCode != 200 {
|
||||
t.Error("expected 200")
|
||||
}
|
||||
body, _ := io.ReadAll(recorder.Result().Body)
|
||||
if string(body) != "!!!" {
|
||||
t.Errorf("invalid body, got %#v", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterBaseRedirect(t *testing.T) {
|
||||
router := NewRouter("/foo")
|
||||
router.For("/", func(c *Context) {
|
||||
c.Out.Write([]byte("!!!"))
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest("GET", "/foo", nil)
|
||||
|
||||
router.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Result().StatusCode != 302 {
|
||||
t.Errorf("expected 302, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
}
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"reflect"
|
||||
|
||||
"github.com/nkanaev/yarr/src/assets"
|
||||
"github.com/nkanaev/yarr/src/auth"
|
||||
"github.com/nkanaev/yarr/src/content/readability"
|
||||
"github.com/nkanaev/yarr/src/content/sanitizer"
|
||||
"github.com/nkanaev/yarr/src/opml"
|
||||
"github.com/nkanaev/yarr/src/router"
|
||||
"github.com/nkanaev/yarr/src/server/router"
|
||||
"github.com/nkanaev/yarr/src/server/auth"
|
||||
"github.com/nkanaev/yarr/src/server/opml"
|
||||
"github.com/nkanaev/yarr/src/storage"
|
||||
"github.com/nkanaev/yarr/src/worker"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user