diff --git a/reader_test.go b/reader_test.go
new file mode 100644
index 0000000..971b3ae
--- /dev/null
+++ b/reader_test.go
@@ -0,0 +1,43 @@
+package validator
+
+import (
+ "bytes"
+ "compress/flate"
+ "io"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// flateIt takes and input string, compresses it using flate, and returns a flate.Reader() of the compressed content
+func flateIt(t *testing.T, input string) io.Reader {
+ t.Helper()
+
+ var zipped bytes.Buffer
+ w, err := flate.NewWriter(&zipped, flate.DefaultCompression)
+ require.NoError(t, err)
+
+ _, err = w.Write([]byte(input))
+ require.NoError(t, err)
+
+ err = w.Close()
+ require.NoError(t, err)
+
+ return flate.NewReader(&zipped)
+}
+
+func TestValidateZippedReader(t *testing.T) {
+ // wrap an innocuous "" XML payload in a flate.Reader :
+ zipped := flateIt(t, ``)
+
+ // Validate should not trigger an error on that Reader :
+ err := Validate(zipped)
+ assert.NoError(t, err, "Should not error on a valid XML document")
+
+ // an invalid document should still error :
+ zipped = flateIt(t, ``)
+
+ err = Validate(zipped)
+ assert.Error(t, err, "Should error on an invalid XML document")
+}
diff --git a/validator.go b/validator.go
index 6931b93..523659d 100644
--- a/validator.go
+++ b/validator.go
@@ -121,9 +121,20 @@ type byteReader struct {
}
func (r *byteReader) ReadByte() (byte, error) {
- p := make([]byte, 1)
- _, err := r.r.Read(p)
- return p[0], err
+ var p [1]byte
+ n, err := r.r.Read(p[:])
+
+ // The doc for the io.ByteReader interface states:
+ // If ReadByte returns an error, no input byte was consumed, and the returned byte value is undefined.
+ // So if a byte is actually extracted from the reader, and we want to return it, we mustn't return the error.
+ if n > 0 {
+ // this byteReader is only used in the context of the Validate() function,
+ // we deliberately choose to completely ignore the error in this case.
+ // return the byte extracted from the reader
+ return p[0], nil
+ }
+
+ return 0, err
}
func (r *byteReader) Read(p []byte) (int, error) {