// Mirror of https://github.com/nkanaev/yarr.git (synced 2025-05-24 00:33:14 +00:00).
// 223 lines, 4.6 KiB, Go.

package server
|
|
|
|
import (
	"context"
	"log"
	"net/http"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/nkanaev/yarr/storage"
)
|
|
|
|
type Handler struct {
|
|
Addr string
|
|
db *storage.Storage
|
|
log *log.Logger
|
|
feedQueue chan storage.Feed
|
|
queueSize *int32
|
|
refreshRate chan int64
|
|
// auth
|
|
Username string
|
|
Password string
|
|
// https
|
|
CertFile string
|
|
KeyFile string
|
|
}
|
|
|
|
func New(db *storage.Storage, logger *log.Logger, addr string) *Handler {
|
|
queueSize := int32(0)
|
|
return &Handler{
|
|
db: db,
|
|
log: logger,
|
|
feedQueue: make(chan storage.Feed, 3000),
|
|
queueSize: &queueSize,
|
|
Addr: addr,
|
|
refreshRate: make(chan int64),
|
|
}
|
|
}
|
|
|
|
func (h *Handler) GetAddr() string {
|
|
proto := "http"
|
|
if h.CertFile != "" && h.KeyFile != "" {
|
|
proto = "https"
|
|
}
|
|
return proto + "://" + h.Addr
|
|
}
|
|
|
|
func (h *Handler) Start() {
|
|
h.startJobs()
|
|
s := &http.Server{Addr: h.Addr, Handler: h}
|
|
|
|
var err error
|
|
if h.CertFile != "" && h.KeyFile != "" {
|
|
err = s.ListenAndServeTLS(h.CertFile, h.KeyFile)
|
|
} else {
|
|
err = s.ListenAndServe()
|
|
}
|
|
if err != http.ErrServerClosed {
|
|
h.log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// unsafeMethod reports whether method is an HTTP method that changes server
// state (POST, PUT, PATCH or DELETE). ServeHTTP uses it to decide which
// requests must carry the X-Requested-By header. The comparison is
// case-sensitive, matching the upper-case names net/http reports.
func unsafeMethod(method string) bool {
	switch method {
	case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
		return true
	}
	return false
}
|
|
|
|
func (h Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
reqPath := req.URL.Path
|
|
if BasePath != "" {
|
|
reqPath = strings.TrimPrefix(req.URL.Path, BasePath)
|
|
if reqPath == "" {
|
|
http.Redirect(rw, req, BasePath+"/", http.StatusFound)
|
|
return
|
|
}
|
|
}
|
|
route, vars := getRoute(reqPath)
|
|
if route == nil {
|
|
rw.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
if h.requiresAuth() && !route.manualAuth {
|
|
if unsafeMethod(req.Method) && req.Header.Get("X-Requested-By") != "yarr" {
|
|
rw.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
if !userIsAuthenticated(req, h.Username, h.Password) {
|
|
rw.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
}
|
|
|
|
ctx := context.WithValue(req.Context(), ctxHandler, &h)
|
|
ctx = context.WithValue(ctx, ctxVars, vars)
|
|
route.handler(rw, req.WithContext(ctx))
|
|
}
|
|
|
|
func (h *Handler) startJobs() {
|
|
delTicker := time.NewTicker(time.Hour * 24)
|
|
|
|
syncSearchChannel := make(chan bool, 10)
|
|
var syncSearchTimer *time.Timer // TODO: should this be atomic?
|
|
|
|
syncSearch := func() {
|
|
if syncSearchTimer == nil {
|
|
syncSearchTimer = time.AfterFunc(time.Second*2, func() {
|
|
syncSearchChannel <- true
|
|
})
|
|
} else {
|
|
syncSearchTimer.Reset(time.Second * 2)
|
|
}
|
|
}
|
|
|
|
worker := func() {
|
|
for {
|
|
select {
|
|
case feed := <-h.feedQueue:
|
|
items, err := listItems(feed, h.db)
|
|
atomic.AddInt32(h.queueSize, -1)
|
|
if err != nil {
|
|
h.log.Printf("Failed to fetch %s (%d): %s", feed.FeedLink, feed.Id, err)
|
|
h.db.SetFeedError(feed.Id, err)
|
|
continue
|
|
}
|
|
h.db.CreateItems(items)
|
|
syncSearch()
|
|
if !feed.HasIcon {
|
|
icon, err := findFavicon(feed.Link, feed.FeedLink)
|
|
if icon != nil {
|
|
h.db.UpdateFeedIcon(feed.Id, icon)
|
|
}
|
|
if err != nil {
|
|
h.log.Printf("Failed to search favicon for %s (%s): %s", feed.Link, feed.FeedLink, err)
|
|
}
|
|
}
|
|
case <-delTicker.C:
|
|
h.db.DeleteOldItems()
|
|
case <-syncSearchChannel:
|
|
h.db.SyncSearch()
|
|
}
|
|
}
|
|
}
|
|
|
|
num := runtime.NumCPU() - 1
|
|
if num < 1 {
|
|
num = 1
|
|
}
|
|
for i := 0; i < num; i++ {
|
|
go worker()
|
|
}
|
|
go h.db.DeleteOldItems()
|
|
go h.db.SyncSearch()
|
|
|
|
go func() {
|
|
var refreshTicker *time.Ticker
|
|
refreshTick := make(<-chan time.Time)
|
|
for {
|
|
select {
|
|
case <-refreshTick:
|
|
h.fetchAllFeeds()
|
|
case val := <-h.refreshRate:
|
|
if refreshTicker != nil {
|
|
refreshTicker.Stop()
|
|
if val == 0 {
|
|
refreshTick = make(<-chan time.Time)
|
|
}
|
|
}
|
|
if val > 0 {
|
|
refreshTicker = time.NewTicker(time.Duration(val) * time.Minute)
|
|
refreshTick = refreshTicker.C
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
refreshRate := h.db.GetSettingsValueInt64("refresh_rate")
|
|
h.refreshRate <- refreshRate
|
|
if refreshRate > 0 {
|
|
h.fetchAllFeeds()
|
|
}
|
|
}
|
|
|
|
func (h Handler) requiresAuth() bool {
|
|
return h.Username != "" && h.Password != ""
|
|
}
|
|
|
|
func (h *Handler) fetchAllFeeds() {
|
|
h.log.Print("Refreshing all feeds")
|
|
h.db.ResetFeedErrors()
|
|
for _, feed := range h.db.ListFeeds() {
|
|
h.fetchFeed(feed)
|
|
}
|
|
}
|
|
|
|
func (h *Handler) fetchFeed(feed storage.Feed) {
|
|
atomic.AddInt32(h.queueSize, 1)
|
|
h.feedQueue <- feed
|
|
}
|
|
|
|
func Vars(req *http.Request) map[string]string {
|
|
if rv := req.Context().Value(ctxVars); rv != nil {
|
|
return rv.(map[string]string)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func db(req *http.Request) *storage.Storage {
|
|
if h := handler(req); h != nil {
|
|
return h.db
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func handler(req *http.Request) *Handler {
|
|
return req.Context().Value(ctxHandler).(*Handler)
|
|
}
|
|
|
|
// ctxKey is the private type of this package's context keys. A dedicated
// type keeps the keys from comparing equal to plain int keys, or to keys of
// other packages, that happen to have the same numeric value.
type ctxKey int

const (
	// ctxVars is the context key for the route variables (map[string]string).
	ctxVars ctxKey = 2
	// ctxHandler is the context key for the *Handler serving the request.
	ctxHandler ctxKey = 3
)
|