From 52cc8ecbbd7d6e35f595909ddf07d182276d09f9 Mon Sep 17 00:00:00 2001 From: nkanaev Date: Mon, 24 Jan 2022 16:47:32 +0000 Subject: [PATCH] fix encoding --- src/parser/feed.go | 40 ++++++++++++++++++++-------- src/server/routes.go | 5 ++-- src/worker/client.go | 13 --------- src/worker/crawler.go | 62 ++++++++++++++++++++++++++++++------------- 4 files changed, 75 insertions(+), 45 deletions(-) diff --git a/src/parser/feed.go b/src/parser/feed.go index 9e160b3..43c7900 100644 --- a/src/parser/feed.go +++ b/src/parser/feed.go @@ -11,20 +11,23 @@ import ( "time" "github.com/nkanaev/yarr/src/content/htmlutil" + "golang.org/x/net/html/charset" ) var UnknownFormat = errors.New("unknown feed format") type processor func(r io.Reader) (*Feed, error) -func sniff(lookup string) (string, processor) { +func sniff(lookup string) (string, bool, processor) { lookup = strings.TrimSpace(lookup) lookup = strings.TrimLeft(lookup, "\x00\xEF\xBB\xBF\xFE\xFF") - if len(lookup) < 0 { - return "", nil + if len(lookup) == 0 { + return "", false, nil } + var decode bool + switch lookup[0] { case '<': decoder := xmlDecoder(strings.NewReader(lookup)) @@ -33,24 +36,32 @@ func sniff(lookup string) (string, processor) { if token == nil { break } + // check for absence of xml encoding + if el, ok := token.(xml.ProcInst); ok && el.Target == "xml" { + decode = strings.Index(string(el.Inst), "encoding=") == -1 + } if el, ok := token.(xml.StartElement); ok { switch el.Name.Local { case "rss": - return "rss", ParseRSS + return "rss", decode, ParseRSS case "RDF": - return "rdf", ParseRDF + return "rdf", decode, ParseRDF case "feed": - return "atom", ParseAtom + return "atom", decode, ParseAtom } } } case '{': - return "json", ParseJSON + return "json", true, ParseJSON } - return "", nil + return "", false, nil } func Parse(r io.Reader) (*Feed, error) { + return ParseWithEncoding(r, "") +} + +func ParseWithEncoding(r io.Reader, fallbackEncoding string) (*Feed, error) { lookup := make([]byte, 2048) n, err := io.ReadFull(r, lookup) switch { @@ -63,11 +74,18 @@ func Parse(r io.Reader) (*Feed, error) { r = io.MultiReader(bytes.NewReader(lookup), r) } - _, callback := sniff(string(lookup)) + _, decode, callback := sniff(string(lookup)) if callback == nil { return nil, UnknownFormat } + if decode && fallbackEncoding != "" { + r, err = charset.NewReaderLabel(fallbackEncoding, r) + if err != nil { + return nil, err + } + } + feed, err := callback(r) if feed != nil { feed.cleanup() @@ -75,8 +93,8 @@ func Parse(r io.Reader) (*Feed, error) { return feed, err } -func ParseAndFix(r io.Reader, baseURL string) (*Feed, error) { - feed, err := Parse(r) +func ParseAndFix(r io.Reader, baseURL, fallbackEncoding string) (*Feed, error) { + feed, err := ParseWithEncoding(r, fallbackEncoding) if err != nil { return nil, err } diff --git a/src/server/routes.go b/src/server/routes.go index ec32b25..66f85a2 100644 --- a/src/server/routes.go +++ b/src/server/routes.go @@ -457,14 +457,13 @@ func (s *Server) handlePageCrawl(c *router.Context) { return } - res, err := worker.GetHTTP(url) + body, err := worker.GetBody(url) if err != nil { log.Print(err) c.Out.WriteHeader(http.StatusBadRequest) return } - defer res.Body.Close() - content, err := readability.ExtractContent(res.Body) + content, err := readability.ExtractContent(strings.NewReader(body)) if err != nil { log.Print(err) c.Out.WriteHeader(http.StatusNoContent) diff --git a/src/worker/client.go b/src/worker/client.go index 49ab3b3..d7cac5d 100644 --- a/src/worker/client.go +++ b/src/worker/client.go @@ -50,16 +50,3 @@ func init() { userAgent: "Yarr/1.0", } } - -func GetHTTP(url string) (*http.Response, error) { - res, err := client.get(url) - if err != nil { - return nil, err - } - body, err := httpBody(res) - if err != nil { - return nil, err - } - res.Body = body - return res, nil -} diff --git a/src/worker/crawler.go b/src/worker/crawler.go index 87d642a..3720a38 100644 --- a/src/worker/crawler.go +++ b/src/worker/crawler.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "mime" "net/http" "net/url" "strings" @@ -38,18 +39,15 @@ func DiscoverFeed(candidateUrl string) (*DiscoverResult, error) { if res.StatusCode != 200 { return nil, fmt.Errorf("status code %d", res.StatusCode) } + cs := getCharset(res) - body, err := httpBody(res) - if err != nil { - return nil, err - } - content, err := ioutil.ReadAll(body) + body, err := io.ReadAll(res.Body) if err != nil { return nil, err } // Try to feed into parser - feed, err := parser.ParseAndFix(bytes.NewReader(content), candidateUrl) + feed, err := parser.ParseAndFix(bytes.NewReader(body), candidateUrl, cs) if err == nil { result.Feed = feed result.FeedLink = candidateUrl @@ -57,8 +55,16 @@ func DiscoverFeed(candidateUrl string) (*DiscoverResult, error) { } // Possibly an html link. Search for feed links + content := string(body) + if cs != "" { + if r, err := charset.NewReaderLabel(cs, bytes.NewReader(body)); err == nil { + if body, err := io.ReadAll(r); err == nil { + content = string(body) + } + } + } sources := make([]FeedSource, 0) - for url, title := range scraper.FindFeeds(string(content), candidateUrl) { + for url, title := range scraper.FindFeeds(content, candidateUrl) { sources = append(sources, FeedSource{Title: title, Url: url}) } switch { @@ -184,12 +190,7 @@ func listItems(f storage.Feed, db *storage.Storage) ([]storage.Item, error) { return nil, nil } - body, err := httpBody(res) - if err != nil { - return nil, err - } - - feed, err := parser.ParseAndFix(body, f.FeedLink) + feed, err := parser.ParseAndFix(res.Body, f.FeedLink, getCharset(res)) if err != nil { return nil, err } @@ -202,14 +203,39 @@ func listItems(f storage.Feed, db *storage.Storage) ([]storage.Item, error) { return ConvertItems(feed.Items, f), nil } -func httpBody(res *http.Response) (io.ReadCloser, error) { +func getCharset(res *http.Response) string { + contentType := res.Header.Get("Content-Type") + if _, params, err := mime.ParseMediaType(contentType); err == nil { + if cs, ok := params["charset"]; ok { + if e, _ := charset.Lookup(cs); e != nil { + return cs + } + } + } + return "" +} + +func GetBody(url string) (string, error) { + res, err := client.get(url) + if err != nil { + return "", err + } + defer res.Body.Close() + + var r io.Reader + ctype := res.Header.Get("Content-Type") if strings.Contains(ctype, "charset") { - reader, err := charset.NewReader(res.Body, ctype) + r, err = charset.NewReader(res.Body, ctype) if err != nil { - return nil, err + return "", err } - return io.NopCloser(reader), nil + } else { + r = res.Body } - return res.Body, nil + body, err := io.ReadAll(r) + if err != nil { + return "", err + } + return string(body), nil }