yarr/server/server.go
2021-01-27 15:46:38 +00:00

227 lines
4.7 KiB
Go

package server
import (
	"context"
	"log"
	"net/http"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/nkanaev/yarr/storage"
)
// Handler is the application's http.Handler. It also owns the state used by
// the background feed-refresh jobs started from Start.
type Handler struct {
	// Addr is the listen address passed to http.Server.
	Addr string

	db  *storage.Storage
	log *log.Logger

	// feedQueue holds feeds waiting to be fetched by the worker goroutines.
	feedQueue chan storage.Feed

	// queueSize is incremented when a feed is enqueued and decremented when a
	// worker has fetched it; it is only touched through sync/atomic.
	queueSize *int32

	// refreshRate delivers a new refresh interval, in minutes, to the
	// scheduling goroutine. Values that are not positive stop the ticker.
	refreshRate chan int64

	// auth
	Username string
	Password string

	// https
	CertFile string
	KeyFile  string
}
// New returns a Handler bound to the given storage, logger and listen
// address. The feed queue is buffered for 3000 feeds, the queue counter starts
// at zero and the refresh-rate channel is unbuffered. Authentication and TLS
// fields are left empty for the caller to fill in.
func New(db *storage.Storage, logger *log.Logger, addr string) *Handler {
	h := new(Handler)
	h.Addr = addr
	h.db = db
	h.log = logger
	h.feedQueue = make(chan storage.Feed, 3000)
	h.queueSize = new(int32)
	h.refreshRate = make(chan int64)
	return h
}
// GetAddr returns the URL the server is reachable at: scheme, listen address
// and BasePath. The scheme is "https" only when both CertFile and KeyFile are
// set; otherwise it is "http".
func (h *Handler) GetAddr() string {
	var scheme string
	switch {
	case h.CertFile == "" || h.KeyFile == "":
		scheme = "http"
	default:
		scheme = "https"
	}
	return scheme + "://" + h.Addr + BasePath
}
// Start launches the background jobs and then serves HTTP on h.Addr, blocking
// until the server stops. TLS is used when both CertFile and KeyFile are set.
// Any error other than http.ErrServerClosed is passed to the logger's Fatal.
func (h *Handler) Start() {
	h.startJobs()

	srv := &http.Server{Addr: h.Addr, Handler: h}

	// Pick the serving mode first so the error handling is written once.
	serve := srv.ListenAndServe
	if h.CertFile != "" && h.KeyFile != "" {
		serve = func() error {
			return srv.ListenAndServeTLS(h.CertFile, h.KeyFile)
		}
	}

	if err := serve(); err != http.ErrServerClosed {
		h.log.Fatal(err)
	}
}
// unsafeMethod reports whether method is one of POST, PUT or DELETE. The
// comparison is exact and case-sensitive.
func unsafeMethod(method string) bool {
	switch method {
	case "POST", "PUT", "DELETE":
		return true
	default:
		return false
	}
}
// ServeHTTP implements http.Handler.
//
// It strips BasePath from the request path (404 when the path does not start
// with it, redirect to BasePath+"/" when nothing remains), resolves the route
// (404 when none matches), and enforces authentication unless the route
// handles auth itself. When authentication applies, POST/PUT/DELETE requests
// must carry the header "X-Requested-By: yarr" and the request must pass
// userIsAuthenticated; failing either yields 401.
//
// The receiver is a value, so the *Handler placed in the request context
// points at this call's copy of the Handler.
func (h Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	path := req.URL.Path
	if BasePath != "" {
		if !strings.HasPrefix(path, BasePath) {
			rw.WriteHeader(http.StatusNotFound)
			return
		}
		// The prefix is known to be present, so slicing removes exactly it.
		path = path[len(BasePath):]
		if path == "" {
			http.Redirect(rw, req, BasePath+"/", http.StatusFound)
			return
		}
	}

	route, vars := getRoute(path)
	if route == nil {
		rw.WriteHeader(http.StatusNotFound)
		return
	}

	if !route.manualAuth && h.requiresAuth() {
		// The header check comes first; credentials are only inspected when
		// it passes.
		headerOK := !unsafeMethod(req.Method) || req.Header.Get("X-Requested-By") == "yarr"
		if !headerOK || !userIsAuthenticated(req, h.Username, h.Password) {
			rw.WriteHeader(http.StatusUnauthorized)
			return
		}
	}

	withHandler := context.WithValue(req.Context(), ctxHandler, &h)
	withVars := context.WithValue(withHandler, ctxVars, vars)
	route.handler(rw, req.WithContext(withVars))
}
// startJobs launches the background goroutines and performs the initial
// refresh. None of the goroutines has a stop signal; they run until the
// process exits.
//
//   - Workers (NumCPU-1 of them, at least one) take feeds from h.feedQueue,
//     fetch and store their items, and look up a favicon for feeds without
//     one. They also delete old items every 24 hours and run the search sync
//     when it is signalled.
//   - The search sync is debounced: each successful fetch (re)arms a
//     two-second timer, and the sync is signalled only when that timer fires.
//   - A scheduler goroutine calls fetchAllFeeds on every tick of a ticker
//     whose interval, in minutes, arrives on h.refreshRate. A value that is
//     not positive stops the current ticker.
//
// Before returning, the stored "refresh_rate" setting is sent to the
// scheduler, and all feeds are fetched once if that rate is positive.
func (h *Handler) startJobs() {
	delTicker := time.NewTicker(time.Hour * 24)

	syncSearchChannel := make(chan bool, 10)

	// syncSearch is called from every worker goroutine, so the lazily created
	// timer must be guarded: without the mutex two workers can race on the nil
	// check and on the assignment.
	var (
		syncSearchMu    sync.Mutex
		syncSearchTimer *time.Timer
	)
	syncSearch := func() {
		syncSearchMu.Lock()
		defer syncSearchMu.Unlock()
		if syncSearchTimer == nil {
			// The callback runs on its own goroutine and does not take the
			// mutex, so a blocked send cannot deadlock syncSearch.
			syncSearchTimer = time.AfterFunc(time.Second*2, func() {
				syncSearchChannel <- true
			})
			return
		}
		syncSearchTimer.Reset(time.Second * 2)
	}

	worker := func() {
		for {
			select {
			case feed := <-h.feedQueue:
				items, err := listItems(feed, h.db)
				// The feed has left the queue whether or not the fetch worked.
				atomic.AddInt32(h.queueSize, -1)
				if err != nil {
					h.log.Printf("Failed to fetch %s (%d): %s", feed.FeedLink, feed.Id, err)
					h.db.SetFeedError(feed.Id, err)
					continue
				}
				h.db.CreateItems(items)
				syncSearch()
				if !feed.HasIcon {
					icon, err := findFavicon(feed.Link, feed.FeedLink)
					if icon != nil {
						h.db.UpdateFeedIcon(feed.Id, icon)
					}
					if err != nil {
						h.log.Printf("Failed to search favicon for %s (%s): %s", feed.Link, feed.FeedLink, err)
					}
				}
			case <-delTicker.C:
				h.db.DeleteOldItems()
			case <-syncSearchChannel:
				h.db.SyncSearch()
			}
		}
	}

	// Leave one CPU free, but always run at least one worker.
	num := runtime.NumCPU() - 1
	if num < 1 {
		num = 1
	}
	for i := 0; i < num; i++ {
		go worker()
	}

	// Run the maintenance tasks once at startup without blocking the caller.
	go h.db.DeleteOldItems()
	go h.db.SyncSearch()

	go func() {
		var refreshTicker *time.Ticker
		// A nil channel is never ready in a select, so refreshing stays off
		// until a positive rate arrives.
		var refreshTick <-chan time.Time
		for {
			select {
			case <-refreshTick:
				h.fetchAllFeeds()
			case val := <-h.refreshRate:
				if refreshTicker != nil {
					refreshTicker.Stop()
					// Drop the stopped ticker's channel so it can no longer
					// trigger a refresh.
					refreshTick = nil
				}
				if val > 0 {
					refreshTicker = time.NewTicker(time.Duration(val) * time.Minute)
					refreshTick = refreshTicker.C
				}
			}
		}
	}()

	refreshRate := h.db.GetSettingsValueInt64("refresh_rate")
	// h.refreshRate is unbuffered: this send completes once the scheduler
	// goroutine above has received the value.
	h.refreshRate <- refreshRate
	if refreshRate > 0 {
		h.fetchAllFeeds()
	}
}
// requiresAuth reports whether authentication is configured, i.e. whether
// both Username and Password are non-empty.
func (h Handler) requiresAuth() bool {
	hasUsername := h.Username != ""
	hasPassword := h.Password != ""
	return hasUsername && hasPassword
}
// fetchAllFeeds logs the refresh, clears the stored feed errors and enqueues
// every feed returned by the storage for fetching.
func (h *Handler) fetchAllFeeds() {
	h.log.Print("Refreshing all feeds")
	h.db.ResetFeedErrors()

	feeds := h.db.ListFeeds()
	for _, f := range feeds {
		h.fetchFeed(f)
	}
}
// fetchFeed enqueues a single feed for the worker goroutines. The send blocks
// while the queue is full.
func (h *Handler) fetchFeed(feed storage.Feed) {
	// Count the feed before handing it over: a worker decrements the counter
	// as soon as it has processed the feed.
	atomic.AddInt32(h.queueSize, 1)

	h.feedQueue <- feed
}
// Vars returns the route variables stored in the request context under
// ctxVars, or nil when the context holds no such value. It panics if the
// stored value is not a map[string]string.
func Vars(req *http.Request) map[string]string {
	stored := req.Context().Value(ctxVars)
	if stored == nil {
		return nil
	}
	return stored.(map[string]string)
}
// db returns the storage of the Handler associated with the request, or nil
// when handler(req) yields no Handler.
func db(req *http.Request) *storage.Storage {
	h := handler(req)
	if h == nil {
		return nil
	}
	return h.db
}
// handler returns the *Handler stored in the request context under
// ctxHandler, or nil when the context holds no such value.
//
// The comma-ok assertion is deliberate: a plain assertion panics on a context
// without a handler, which would make the nil check in db unreachable.
func handler(req *http.Request) *Handler {
	h, _ := req.Context().Value(ctxHandler).(*Handler)
	return h
}
// ctxKey is a private type for this package's request-context keys. Context
// keys are compared by both type and value, so a dedicated type keeps these
// keys from colliding with plain integer keys set by other code.
type ctxKey int

// Keys under which ServeHTTP stores per-request values in the context.
const (
	// ctxVars holds the route variables returned by the route lookup.
	ctxVars ctxKey = 2
	// ctxHandler holds the *Handler serving the request.
	ctxHandler ctxKey = 3
)