reorganizing server-related packages

This commit is contained in:
Nazar Kanaev
2021-04-01 00:24:18 +01:00
parent b04e8c1e93
commit 528df7fb4a
12 changed files with 4 additions and 4 deletions

52
src/server/auth/auth.go Normal file
View File

@@ -0,0 +1,52 @@
package auth
import (
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"net/http"
"strings"
"time"
)
func IsAuthenticated(req *http.Request, username, password string) bool {
cookie, _ := req.Cookie("auth")
if cookie == nil {
return false
}
parts := strings.Split(cookie.Value, ":")
if len(parts) != 2 || !StringsEqual(parts[0], username) {
return false
}
return StringsEqual(parts[1], secret(username, password))
}
func Authenticate(rw http.ResponseWriter, username, password, basepath string) {
http.SetCookie(rw, &http.Cookie{
Name: "auth",
Value: username + ":" + secret(username, password),
Expires: time.Now().Add(time.Hour * 24 * 7), // 1 week,
Path: basepath,
})
}
func Logout(rw http.ResponseWriter, basepath string) {
http.SetCookie(rw, &http.Cookie{
Name: "auth",
Value: "",
MaxAge: -1,
Path: basepath,
})
}
func StringsEqual(p1, p2 string) bool {
return subtle.ConstantTimeCompare([]byte(p1), []byte(p2)) == 1
}
func secret(msg, key string) string {
mac := hmac.New(sha256.New, []byte(key))
mac.Write([]byte(msg))
src := mac.Sum(nil)
return hex.EncodeToString(src)
}

View File

@@ -0,0 +1,55 @@
package auth
import (
"net/http"
"strings"
"github.com/nkanaev/yarr/src/assets"
"github.com/nkanaev/yarr/src/server/router"
)
type Middleware struct {
Username string
Password string
BasePath string
Public string
}
func unsafeMethod(method string) bool {
return method == "POST" || method == "PUT" || method == "DELETE"
}
func (m *Middleware) Handler(c *router.Context) {
if strings.HasPrefix(c.Req.URL.Path, m.BasePath+m.Public) {
c.Next()
return
}
if IsAuthenticated(c.Req, m.Username, m.Password) {
c.Next()
return
}
rootUrl := m.BasePath + "/"
if c.Req.URL.Path != rootUrl {
c.Out.WriteHeader(http.StatusUnauthorized)
return
}
if c.Req.Method == "POST" {
username := c.Req.FormValue("username")
password := c.Req.FormValue("password")
if StringsEqual(username, m.Username) && StringsEqual(password, m.Password) {
Authenticate(c.Out, m.Username, m.Password, m.BasePath)
c.Redirect(rootUrl)
return
} else {
c.HTML(http.StatusOK, assets.Template("login.html"), map[string]string{
"username": username,
"error": "Invalid username/password",
})
return
}
}
c.HTML(http.StatusOK, assets.Template("login.html"), nil)
}

70
src/server/opml/opml.go Normal file
View File

@@ -0,0 +1,70 @@
package opml
import (
"fmt"
"html"
"strings"
)
type Folder struct {
Title string
Folders []Folder
Feeds []Feed
}
type Feed struct {
Title string
FeedUrl string
SiteUrl string
}
func (f Folder) AllFeeds() []Feed {
feeds := make([]Feed, 0)
feeds = append(feeds, f.Feeds...)
for _, subfolder := range f.Folders {
feeds = append(feeds, subfolder.AllFeeds()...)
}
return feeds
}
var e = html.EscapeString
var indent = " "
var nl = "\n"
func (f Folder) outline(level int) string {
builder := strings.Builder{}
prefix := strings.Repeat(indent, level)
if level > 0 {
builder.WriteString(prefix + fmt.Sprintf(`<outline text="%s">`+nl, e(f.Title)))
}
for _, folder := range f.Folders {
builder.WriteString(folder.outline(level + 1))
}
for _, feed := range f.Feeds {
builder.WriteString(feed.outline(level + 1))
}
if level > 0 {
builder.WriteString(prefix + `</outline>` + nl)
}
return builder.String()
}
func (f Feed) outline(level int) string {
return strings.Repeat(indent, level) + fmt.Sprintf(
`<outline type="rss" text="%s" xmlUrl="%s" htmlUrl="%s"/>`+nl,
e(f.Title), e(f.FeedUrl), e(f.SiteUrl),
)
}
func (f Folder) OPML() string {
builder := strings.Builder{}
builder.WriteString(`<?xml version="1.0" encoding="UTF-8"?>` + nl)
builder.WriteString(`<opml version="1.1">` + nl)
builder.WriteString(`<head><title>subscriptions</title></head>` + nl)
builder.WriteString(`<body>` + nl)
builder.WriteString(f.outline(0))
builder.WriteString(`</body>` + nl)
builder.WriteString(`</opml>` + nl)
return builder.String()
}

View File

@@ -0,0 +1,54 @@
package opml
import (
"reflect"
"testing"
)
func TestOPML(t *testing.T) {
have := (Folder{
Title: "",
Feeds: []Feed{
{
Title: "title1",
FeedUrl: "https://baz.com/feed.xml",
SiteUrl: "https://baz.com/",
},
},
Folders: []Folder{
{
Title: "sub",
Feeds: []Feed{
{
Title: "subtitle1",
FeedUrl: "https://foo.com/feed.xml",
SiteUrl: "https://foo.com/",
},
{
Title: "&>",
FeedUrl: "https://bar.com/feed.xml",
SiteUrl: "https://bar.com/",
},
},
Folders: []Folder{},
},
},
}).OPML()
want := `<?xml version="1.0" encoding="UTF-8"?>
<opml version="1.1">
<head><title>subscriptions</title></head>
<body>
<outline text="sub">
<outline type="rss" text="subtitle1" xmlUrl="https://foo.com/feed.xml" htmlUrl="https://foo.com/"/>
<outline type="rss" text="&amp;&gt;" xmlUrl="https://bar.com/feed.xml" htmlUrl="https://bar.com/"/>
</outline>
<outline type="rss" text="title1" xmlUrl="https://baz.com/feed.xml" htmlUrl="https://baz.com/"/>
</body>
</opml>
`
if !reflect.DeepEqual(want, have) {
t.Logf("want: %s", want)
t.Logf("have: %s", have)
t.Fatal("invalid opml")
}
}

49
src/server/opml/read.go Normal file
View File

@@ -0,0 +1,49 @@
package opml
import (
"encoding/xml"
"io"
)
type opml struct {
XMLName xml.Name `xml:"opml"`
Outlines []outline `xml:"body>outline"`
}
type outline struct {
Type string `xml:"type,attr,omitempty"`
Title string `xml:"text,attr"`
FeedUrl string `xml:"xmlUrl,attr,omitempty"`
SiteUrl string `xml:"htmlUrl,attr,omitempty"`
Outlines []outline `xml:"outline,omitempty"`
}
func buildFolder(title string, outlines []outline) Folder {
folder := Folder{Title: title}
for _, outline := range outlines {
if outline.Type == "rss" {
folder.Feeds = append(folder.Feeds, Feed{
Title: outline.Title,
FeedUrl: outline.FeedUrl,
SiteUrl: outline.SiteUrl,
})
} else {
subfolder := buildFolder(outline.Title, outline.Outlines)
folder.Folders = append(folder.Folders, subfolder)
}
}
return folder
}
func Parse(r io.Reader) (Folder, error) {
val := new(opml)
decoder := xml.NewDecoder(r)
decoder.Entity = xml.HTMLEntity
decoder.Strict = false
err := decoder.Decode(&val)
if err != nil {
return Folder{}, err
}
return buildFolder("", val.Outlines), nil
}

View File

@@ -0,0 +1,58 @@
package opml
import (
"reflect"
"strings"
"testing"
)
func TestParse(t *testing.T) {
have, _ := Parse(strings.NewReader(`
<?xml version="1.0" encoding="UTF-8"?>
<opml version="1.1">
<head><title>Subscriptions</title></head>
<body>
<outline text="sub">
<outline type="rss" text="subtitle1" description="sub1"
xmlUrl="https://foo.com/feed.xml" htmlUrl="https://foo.com/"/>
<outline type="rss" text="&amp;&gt;" description="&lt;&gt;"
xmlUrl="https://bar.com/feed.xml" htmlUrl="https://bar.com/"/>
</outline>
<outline type="rss" text="title1" description="desc1"
xmlUrl="https://baz.com/feed.xml" htmlUrl="https://baz.com/"/>
</body>
</opml>
`))
want := Folder{
Title: "",
Feeds: []Feed{
{
Title: "title1",
FeedUrl: "https://baz.com/feed.xml",
SiteUrl: "https://baz.com/",
},
},
Folders: []Folder{
{
Title: "sub",
Feeds: []Feed{
{
Title: "subtitle1",
FeedUrl: "https://foo.com/feed.xml",
SiteUrl: "https://foo.com/",
},
{
Title: "&>",
FeedUrl: "https://bar.com/feed.xml",
SiteUrl: "https://bar.com/",
},
},
},
},
}
if !reflect.DeepEqual(want, have) {
t.Logf("want: %#v", want)
t.Logf("have: %#v", have)
t.Fatal("invalid opml")
}
}

View File

@@ -0,0 +1,61 @@
package router
import (
"encoding/json"
"fmt"
"html/template"
"log"
"net/http"
"strconv"
)
type Context struct {
Req *http.Request
Out http.ResponseWriter
Vars map[string]string
chain []Handler
index int
}
func (c *Context) Next() {
c.index++
c.chain[c.index](c)
}
func (c *Context) JSON(status int, data interface{}) {
body, err := json.Marshal(data)
if err != nil {
log.Fatal(err)
}
c.Out.WriteHeader(status)
c.Out.Header().Set("Content-Type", "application/json; charset=utf-8")
c.Out.Write(body)
c.Out.Write([]byte("\n"))
}
func (c *Context) HTML(status int, tmpl *template.Template, data interface{}) {
c.Out.WriteHeader(status)
c.Out.Header().Set("Content-Type", "text/html")
tmpl.Execute(c.Out, data)
}
func (c *Context) VarInt64(key string) (int64, error) {
if val, ok := c.Vars[key]; ok {
return strconv.ParseInt(val, 10, 64)
}
return 0, fmt.Errorf("no such var: %s", key)
}
func (c *Context) QueryInt64(key string) (int64, error) {
query := c.Req.URL.Query()
return strconv.ParseInt(query.Get(key), 10, 64)
}
func (c *Context) Redirect(url string) {
if url == "" {
url = "/"
}
http.Redirect(c.Out, c.Req, url, http.StatusFound)
}

View File

@@ -0,0 +1,24 @@
package router
import "regexp"
func regexGroups(input string, regex *regexp.Regexp) map[string]string {
groups := make(map[string]string)
matches := regex.FindStringSubmatchIndex(input)
for i, key := range regex.SubexpNames()[1:] {
groups[key] = input[matches[i*2+2]:matches[i*2+3]]
}
return groups
}
func routeRegexp(route string) *regexp.Regexp {
chunks := regexp.MustCompile(`[\*\:]\w+`)
output := chunks.ReplaceAllStringFunc(route, func(m string) string {
if m[0:1] == `*` {
return "(?P<" + m[1:] + ">.+)"
}
return "(?P<" + m[1:] + ">[^/]+)"
})
output = "^" + output + "$"
return regexp.MustCompile(output)
}

View File

@@ -0,0 +1,76 @@
package router
import (
"reflect"
"testing"
)
func TestRouteRegexpPart(t *testing.T) {
in := "/hello/:world"
re := routeRegexp(in)
pos := []string{
"/hello/world",
"/hello/1234",
"/hello/bbc1",
}
for _, c := range pos {
if !re.MatchString(c) {
t.Errorf("%v must match %v", in, c)
}
}
neg := []string{
"/hello",
"/hello/world/",
"/sub/hello/123",
"//hello/123",
"/hello/123/hello/",
}
for _, c := range neg {
if re.MatchString(c) {
t.Errorf("%q must not match %q", in, c)
}
}
}
func TestRouteRegexpStar(t *testing.T) {
in := "/hello/*world"
re := routeRegexp(in)
pos := []string{"/hello/world", "/hello/world/test"}
for _, c := range pos {
if !re.MatchString(c) {
t.Errorf("%q must match %q", in, c)
}
}
neg := []string{"/hello/", "/hello"}
for _, c := range neg {
if re.MatchString(c) {
t.Errorf("%v must not match %v", in, c)
}
}
}
func TestRegexGroupsPart(t *testing.T) {
re := routeRegexp("/foo/:bar/1/:baz")
expect := map[string]string{"bar": "one", "baz": "two"}
actual := regexGroups("/foo/one/1/two", re)
if !reflect.DeepEqual(expect, actual) {
t.Errorf("expected: %q, actual: %q", expect, actual)
}
}
func TestRegexGroupsStar(t *testing.T) {
re := routeRegexp("/foo/*bar")
expect := map[string]string{"bar": "bar/baz/"}
actual := regexGroups("/foo/bar/baz/", re)
if !reflect.DeepEqual(expect, actual) {
t.Errorf("expected: %q, actual: %q", expect, actual)
}
}

View File

@@ -0,0 +1,73 @@
package router
import (
"net/http"
"regexp"
"strings"
)
type Handler func(*Context)
type Router struct {
middle []Handler
routes []Route
base string
}
type Route struct {
regex *regexp.Regexp
chain []Handler
}
func NewRouter(base string) *Router {
router := &Router{}
router.middle = make([]Handler, 0)
router.routes = make([]Route, 0)
router.base = base
return router
}
func (r *Router) Use(h Handler) {
r.middle = append(r.middle, h)
}
func (r *Router) For(path string, handler Handler) {
x := Route{}
x.regex = routeRegexp(path)
x.chain = append(r.middle, handler)
r.routes = append(r.routes, x)
}
func (r *Router) resolve(path string) *Route {
for _, route := range r.routes {
if route.regex.MatchString(path) {
return &route
}
}
return nil
}
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// autoclose open base url
if r.base != "" && r.base == req.URL.Path {
http.Redirect(rw, req, r.base+"/", http.StatusFound)
return
}
path := strings.TrimPrefix(req.URL.Path, r.base)
route := r.resolve(path)
if route == nil {
rw.WriteHeader(http.StatusNotFound)
return
}
context := &Context{}
context.Req = req
context.Out = rw
context.Vars = regexGroups(path, route.regex)
context.index = -1
context.chain = route.chain
context.Next()
}

View File

@@ -0,0 +1,147 @@
package router
import (
"io"
"net/http/httptest"
"testing"
)
func TestRouter(t *testing.T) {
middlecalled := false
router := NewRouter("")
router.Use(func(c *Context) {
middlecalled = true
c.Next()
})
router.For("/hello/:place", func(c *Context) {
c.Out.Write([]byte(c.Vars["place"]))
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/hello/world", nil)
router.ServeHTTP(recorder, request)
body, _ := io.ReadAll(recorder.Result().Body)
if !middlecalled {
t.Error("middleware not called")
}
if recorder.Result().StatusCode != 200 {
t.Error("expected 200")
}
if string(body) != "world" {
t.Errorf("invalid response body, got %#v", body)
}
}
func TestRouterPaths(t *testing.T) {
router := NewRouter("")
router.For("/path/to/foo", func(c *Context) {
c.Out.Write([]byte("foo"))
})
router.For("/path/to/bar", func(c *Context) {
c.Out.Write([]byte("bar"))
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/path/to/bar", nil)
router.ServeHTTP(recorder, request)
body, _ := io.ReadAll(recorder.Result().Body)
if string(body) != "bar" {
t.Error("expected 2nd route to be called")
}
}
func TestRouterMiddlewareIntercept(t *testing.T) {
router := NewRouter("")
router.Use(func(c *Context) {
c.Out.WriteHeader(404)
})
router.For("/hello/:place", func(c *Context) {
c.Out.WriteHeader(200)
c.Out.Write([]byte(c.Vars["place"]))
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/hello/world", nil)
router.ServeHTTP(recorder, request)
if recorder.Result().StatusCode != 404 {
t.Error("expected 404")
}
body, _ := io.ReadAll(recorder.Result().Body)
if len(body) != 0 {
t.Errorf("expected empty body, got %v", body)
}
}
func TestRouterMiddlewareOrder(t *testing.T) {
router := NewRouter("")
router.Use(func(c *Context) {
c.Out.Write([]byte("foo"))
c.Next()
})
router.Use(func(c *Context) {
c.Out.Write([]byte("bar"))
c.Next()
})
router.For("/hello/:place", func(c *Context) {
c.Out.Write([]byte("!!!"))
})
router.Use(func(c *Context) {
c.Out.Write([]byte("baz"))
c.Next()
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/hello/world", nil)
router.ServeHTTP(recorder, request)
if recorder.Result().StatusCode != 200 {
t.Error("expected 200")
}
body, _ := io.ReadAll(recorder.Result().Body)
if string(body) != "foobar!!!" {
t.Errorf("invalid body, got %#v", string(body))
}
}
func TestRouterBase(t *testing.T) {
router := NewRouter("/foo")
router.For("/bar", func(c *Context) {
c.Out.Write([]byte("!!!"))
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/foo/bar", nil)
router.ServeHTTP(recorder, request)
if recorder.Result().StatusCode != 200 {
t.Error("expected 200")
}
body, _ := io.ReadAll(recorder.Result().Body)
if string(body) != "!!!" {
t.Errorf("invalid body, got %#v", string(body))
}
}
func TestRouterBaseRedirect(t *testing.T) {
router := NewRouter("/foo")
router.For("/", func(c *Context) {
c.Out.Write([]byte("!!!"))
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest("GET", "/foo", nil)
router.ServeHTTP(recorder, request)
if recorder.Result().StatusCode != 302 {
t.Errorf("expected 302, got %d", recorder.Result().StatusCode)
}
}

View File

@@ -8,11 +8,11 @@ import (
"reflect"
"github.com/nkanaev/yarr/src/assets"
"github.com/nkanaev/yarr/src/auth"
"github.com/nkanaev/yarr/src/content/readability"
"github.com/nkanaev/yarr/src/content/sanitizer"
"github.com/nkanaev/yarr/src/opml"
"github.com/nkanaev/yarr/src/router"
"github.com/nkanaev/yarr/src/server/router"
"github.com/nkanaev/yarr/src/server/auth"
"github.com/nkanaev/yarr/src/server/opml"
"github.com/nkanaev/yarr/src/storage"
"github.com/nkanaev/yarr/src/worker"
)