diff --git a/server/crawler.go b/server/crawler.go
index 832fd5b..8929cb3 100644
--- a/server/crawler.go
+++ b/server/crawler.go
@@ -10,6 +10,7 @@ import (
 	"io/ioutil"
 	"net/http"
 	"net/url"
+	"time"
 )
 
 type FeedSource struct {
@@ -31,6 +32,28 @@ const feedLinks = `
 	a:contains("FEED")
 `
 
+// Client wraps an http.Client and sets a fixed User-Agent header on every
+// request sent through get.
+type Client struct {
+	httpClient *http.Client
+	userAgent  string
+}
+
+// get issues a GET request for url with the client's User-Agent header set.
+// When the returned error is nil the caller must close the response body.
+func (c *Client) get(url string) (*http.Response, error) {
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("User-Agent", c.userAgent)
+	return c.httpClient.Do(req)
+}
+
+// defaultClient is the client shared by discoverFeed, findFavicon and
+// listItems; it is assigned in init.
+var defaultClient *Client
+
 func searchFeedLinks(html []byte, siteurl string) ([]FeedSource, error) {
 	sources := make([]FeedSource, 0, 0)
 
@@ -57,15 +80,14 @@ func searchFeedLinks(html []byte, siteurl string) ([]FeedSource, error) {
 	return sources, nil
 }
 
-func discoverFeed(url, userAgent string) (*gofeed.Feed, *[]FeedSource, error) {
+func discoverFeed(url string) (*gofeed.Feed, *[]FeedSource, error) {
 	// Query URL
-	feedreq, _ := http.NewRequest("GET", url, nil)
-	feedreq.Header.Set("user-agent", userAgent)
-	feedclient := &http.Client{}
-	res, err := feedclient.Do(feedreq)
+	res, err := defaultClient.get(url)
 	if err != nil {
 		return nil, nil, err
-	} else if res.StatusCode != 200 {
+	}
+	defer res.Body.Close()
+	if res.StatusCode != 200 {
 		errmsg := fmt.Sprintf("Failed to fetch feed %s (status: %d)", url, res.StatusCode)
 		return nil, nil, errors.New(errmsg)
 	}
@@ -95,7 +117,7 @@ func discoverFeed(url, userAgent string) (*gofeed.Feed, *[]FeedSource, error) {
 		if sources[0].Url == url {
 			return nil, nil, errors.New("Recursion!")
 		}
-		return discoverFeed(sources[0].Url, userAgent)
+		return discoverFeed(sources[0].Url)
 	}
 	return nil, &sources, nil
 }
@@ -134,8 +156,6 @@ func findFavicon(websiteUrl, feedUrl string) (*[]byte, error) {
 		candidateUrls = append(candidateUrls, c)
 	}
 
-	client := http.Client{}
-
 	imageTypes := [4]string{
 		"image/x-icon",
 		"image/png",
@@ -143,7 +163,14 @@ func findFavicon(websiteUrl, feedUrl string) (*[]byte, error) {
 		"image/gif",
 	}
 	for _, url := range candidateUrls {
-		if res, err := client.Get(url); err == nil && res.StatusCode == 200 {
+		res, err := defaultClient.get(url)
+		if err != nil {
+			continue
+		}
+		// NOTE(review): this defer sits inside the loop, so each response
+		// body stays open until findFavicon returns.
+		defer res.Body.Close()
+		if res.StatusCode == 200 {
 			if content, err := ioutil.ReadAll(res.Body); err == nil {
 				ctype := http.DetectContentType(content)
 				for _, itype := range imageTypes {
@@ -186,10 +213,30 @@ func convertItems(items []*gofeed.Item, feed storage.Feed) []storage.Item {
 }
 
+// listItems fetches f.FeedLink through defaultClient, parses the response
+// body as a feed and returns its items converted with convertItems.
 func listItems(f storage.Feed) ([]storage.Item, error) {
-	fp := gofeed.NewParser()
-	feed, err := fp.ParseURL(f.FeedLink)
+	res, err := defaultClient.get(f.FeedLink)
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+	// Reject non-2xx responses so an error page is never handed to the parser.
+	if res.StatusCode < 200 || res.StatusCode >= 300 {
+		return nil, fmt.Errorf("Failed to fetch feed %s (status: %d)", f.FeedLink, res.StatusCode)
+	}
+	feedparser := gofeed.NewParser()
+	feed, err := feedparser.Parse(res.Body)
 	if err != nil {
 		return nil, err
 	}
 	return convertItems(feed.Items, f), nil
 }
+
+// init builds defaultClient with a five-second request timeout and the
+// "Yarr/1.0" User-Agent.
+func init() {
+	defaultClient = &Client{
+		httpClient: &http.Client{Timeout: time.Second * 5},
+		userAgent:  "Yarr/1.0",
+	}
+}
diff --git a/server/handlers.go b/server/handlers.go
index b3d47d2..c6cc63c 100644
--- a/server/handlers.go
+++ b/server/handlers.go
@@ -227,7 +227,7 @@ func FeedListHandler(rw http.ResponseWriter, req *http.Request) {
 			return
 		}
 
-		feed, sources, err := discoverFeed(form.Url, req.Header.Get("user-agent"))
+		feed, sources, err := discoverFeed(form.Url)
 		if err != nil {
 			handler(req).log.Print(err)
 			writeJSON(rw, map[string]string{"status": "notfound"})