feed refactoring

This commit is contained in:
Nazar Kanaev 2021-03-22 21:12:58 +00:00
parent e78c028d20
commit 7ca9415322
2 changed files with 28 additions and 23 deletions

View File

@ -1,6 +1,7 @@
package feed package feed
import ( import (
"bytes"
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt" "fmt"
@ -12,42 +13,46 @@ var UnknownFormat = errors.New("unknown feed format")
type processor func(r io.Reader) (*Feed, error) type processor func(r io.Reader) (*Feed, error)
func detect(lookup string) (string, processor) { func sniff(lookup string) (string, processor) {
lookup = strings.TrimSpace(lookup) lookup = strings.TrimSpace(lookup)
if lookup[0] == '{' { switch lookup[0] {
return "json", ParseJSON case '<':
} decoder := xml.NewDecoder(strings.NewReader(lookup))
decoder := xml.NewDecoder(strings.NewReader(lookup)) for {
for { token, _ := decoder.Token()
token, _ := decoder.Token() if token == nil {
if token == nil { break
break }
} if el, ok := token.(xml.StartElement); ok {
if el, ok := token.(xml.StartElement); ok { switch el.Name.Local {
switch el.Name.Local { case "rss":
case "rss": return "rss", ParseRSS
return "rss", ParseRSS case "RDF":
case "RDF": return "rss", ParseRDF
return "rss", ParseRDF case "feed":
case "feed": return "atom", ParseAtom
return "atom", ParseAtom }
} }
} }
case '{':
return "json", ParseJSON
} }
return "", nil return "", nil
} }
func Parse(r io.Reader) (*Feed, error) { func Parse(r io.Reader) (*Feed, error) {
var x [1024]byte chunk := make([]byte, 64)
numread, err := r.Read(x[:]) numread, err := r.Read(chunk)
fmt.Println(numread, err) fmt.Println(numread, err)
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed to read: %s", err) return nil, fmt.Errorf("Failed to read: %s", err)
} }
_, callback := detect(string(x[:])) _, callback := sniff(string(chunk))
if callback == nil { if callback == nil {
return nil, UnknownFormat return nil, UnknownFormat
} }
r = io.MultiReader(bytes.NewReader(chunk), r)
return callback(r) return callback(r)
} }

View File

@ -2,7 +2,7 @@ package feed
import "testing" import "testing"
func TestDetect(t *testing.T) { func TestSniff(t *testing.T) {
testcases := [][2]string{ testcases := [][2]string{
{ {
`<?xml version="1.0"?><rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"></rdf:RDF>`, `<?xml version="1.0"?><rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"></rdf:RDF>`,
@ -26,7 +26,7 @@ func TestDetect(t *testing.T) {
}, },
} }
for _, testcase := range testcases { for _, testcase := range testcases {
have, _ := detect(testcase[0]) have, _ := sniff(testcase[0])
want := testcase[1] want := testcase[1]
if want != have { if want != have {
t.Log(testcase[0]) t.Log(testcase[0])