From d7253a60b806e4e03f65452af85371cc3f87c100 Mon Sep 17 00:00:00 2001
From: Nazar Kanaev
Date: Sat, 12 Feb 2022 23:42:44 +0000
Subject: [PATCH] strip out invalid xml characters

---
 src/parser/util.go      |  75 +++++++++++++++++++++++++-
 src/parser/util_test.go | 115 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 189 insertions(+), 1 deletion(-)
 create mode 100644 src/parser/util_test.go

diff --git a/src/parser/util.go b/src/parser/util.go
index ba32673..6c14439 100644
--- a/src/parser/util.go
+++ b/src/parser/util.go
@@ -1,10 +1,12 @@
 package parser
 
 import (
+	"bufio"
 	"encoding/xml"
 	"io"
 	"regexp"
 	"strings"
+	"unicode/utf8"
 
 	"golang.org/x/net/html/charset"
 )
@@ -28,8 +30,79 @@ func plain2html(text string) string {
 }
 
 func xmlDecoder(r io.Reader) *xml.Decoder {
-	decoder := xml.NewDecoder(r)
+	decoder := xml.NewDecoder(NewSafeXMLReader(r))
 	decoder.Strict = false
 	decoder.CharsetReader = charset.NewReaderLabel
 	return decoder
 }
+
+// safexmlreader is an io.Reader that drops the runes which are not allowed
+// in XML documents from the stream it wraps.
+type safexmlreader struct {
+	// reader is the buffered source the runes are decoded from.
+	reader *bufio.Reader
+
+	// buffer holds sanitized bytes not yet handed to the caller.
+	buffer []byte
+
+	// err is the sticky error (io.EOF included) returned by reader; it is
+	// reported only after buffer has been drained.
+	err error
+}
+
+// NewSafeXMLReader returns a reader that yields the content of r without
+// the runes that fall outside of the XML character range.
+//
+// The input is decoded rune by rune with bufio.Reader.ReadRune, therefore
+// a byte that is not valid UTF-8 comes out as U+FFFD.
+// NOTE(review): xmlDecoder filters the stream before CharsetReader gets to
+// see it, so feeds in other encodings are presumably affected - confirm.
+func NewSafeXMLReader(r io.Reader) io.Reader {
+	return &safexmlreader{reader: bufio.NewReader(r)}
+}
+
+// Read implements io.Reader. It fills p with bytes of the wrapped stream,
+// leaving out every rune that is rejected by isInCharacterRange.
+func (xr *safexmlreader) Read(p []byte) (int, error) {
+	if len(p) == 0 {
+		return 0, nil
+	}
+
+	// Decode whole runes until there is enough data to fill p. The limit is
+	// len(p) rather than cap(p): a reader must not touch or count the bytes
+	// beyond the length of the slice it was given.
+	var enc [utf8.UTFMax]byte
+	for xr.err == nil && len(xr.buffer) < len(p) {
+		r, _, err := xr.reader.ReadRune()
+		if err != nil {
+			// Keep the error for later, pending bytes are delivered first.
+			xr.err = err
+			break
+		}
+		if isInCharacterRange(r) {
+			size := utf8.EncodeRune(enc[:], r)
+			xr.buffer = append(xr.buffer, enc[:size]...)
+		}
+	}
+
+	if len(xr.buffer) == 0 {
+		return 0, xr.err
+	}
+
+	n := copy(p, xr.buffer)
+	xr.buffer = xr.buffer[n:]
+	return n, nil
+}
+
+// NOTE: copied from "encoding/xml" package
+// Decide whether the given rune is in the XML Character Range, per
+// the Char production of https://www.xml.com/axml/testaxml.htm,
+// Section 2.2 Characters.
+func isInCharacterRange(r rune) (inrange bool) {
+	return r == 0x09 ||
+		r == 0x0A ||
+		r == 0x0D ||
+		r >= 0x20 && r <= 0xD7FF ||
+		r >= 0xE000 && r <= 0xFFFD ||
+		r >= 0x10000 && r <= 0x10FFFF
+}
diff --git a/src/parser/util_test.go b/src/parser/util_test.go
new file mode 100644
index 0000000..3cd9ed4
--- /dev/null
+++ b/src/parser/util_test.go
@@ -0,0 +1,115 @@
+package parser
+
+import (
+	"bytes"
+	"io"
+	"reflect"
+	"testing"
+)
+
+func TestSafeXMLReader(t *testing.T) {
+	var f io.Reader
+	want := []byte("привет мир")
+	f = bytes.NewReader(want)
+	f = NewSafeXMLReader(f)
+
+	have, err := io.ReadAll(f)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !reflect.DeepEqual(want, have) {
+		t.Fatalf("invalid output\nwant: %v\nhave: %v", want, have)
+	}
+}
+
+func TestSafeXMLReaderRemoveUnwantedRunes(t *testing.T) {
+	var f io.Reader
+	input := []byte("\aпривет \x0cмир\ufffe\uffff")
+	want := []byte("привет мир")
+	f = bytes.NewReader(input)
+	f = NewSafeXMLReader(f)
+
+	have, err := io.ReadAll(f)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !reflect.DeepEqual(want, have) {
+		t.Fatalf("invalid output\nwant: %v\nhave: %v", want, have)
+	}
+}
+
+func TestSafeXMLReaderPartial1(t *testing.T) {
+	var f io.Reader
+	input := []byte("\aпривет \x0cмир\ufffe\uffff")
+	want := []byte("привет мир")
+	f = bytes.NewReader(input)
+	f = NewSafeXMLReader(f)
+
+	buf := make([]byte, 1)
+	for i := 0; i < len(want); i++ {
+		n, err := f.Read(buf)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if n != 1 {
+			t.Fatalf("expected 1 byte, got %d", n)
+		}
+		if buf[0] != want[i] {
+			t.Fatalf("invalid char at pos %d\nwant: %v\nhave: %v", i, want[i], buf[0])
+		}
+	}
+	if x, err := f.Read(buf); err != io.EOF {
+		t.Fatalf("expected EOF, %v, %v %v", buf, x, err)
+	}
+}
+
+func TestSafeXMLReaderPartial2(t *testing.T) {
+	var f io.Reader
+	input := []byte("привет\a\a\a\a\a")
+	f = bytes.NewReader(input)
+	f = NewSafeXMLReader(f)
+
+	buf := make([]byte, 12)
+	n, err := f.Read(buf)
+	if err != nil {
+		t.Fatalf("unexpected error: %s", err)
+	}
+	if n != 12 {
+		t.Fatalf("expected 12 bytes")
+	}
+
+	n, err = f.Read(buf)
+	if n != 0 {
+		t.Fatalf("expected 0")
+	}
+	if err != io.EOF {
+		t.Fatalf("expected EOF, got %v", err)
+	}
+}
+
+func TestSafeXMLReaderShortBuffer(t *testing.T) {
+	var f io.Reader
+	want := []byte("привет мир")
+	f = bytes.NewReader(want)
+	f = NewSafeXMLReader(f)
+
+	// the spare capacity of the buffer must never be written to or counted
+	buf := make([]byte, 4, 64)
+	var have []byte
+	for {
+		n, err := f.Read(buf)
+		if n > len(buf) {
+			t.Fatalf("expected at most %d bytes, got %d", len(buf), n)
+		}
+		have = append(have, buf[:n]...)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	if !reflect.DeepEqual(want, have) {
+		t.Fatalf("invalid output\nwant: %v\nhave: %v", want, have)
+	}
+}