diff --git a/reader_test.go b/reader_test.go new file mode 100644 index 0000000..971b3ae --- /dev/null +++ b/reader_test.go @@ -0,0 +1,43 @@ +package validator + +import ( + "bytes" + "compress/flate" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// flateIt takes and input string, compresses it using flate, and returns a flate.Reader() of the compressed content +func flateIt(t *testing.T, input string) io.Reader { + t.Helper() + + var zipped bytes.Buffer + w, err := flate.NewWriter(&zipped, flate.DefaultCompression) + require.NoError(t, err) + + _, err = w.Write([]byte(input)) + require.NoError(t, err) + + err = w.Close() + require.NoError(t, err) + + return flate.NewReader(&zipped) +} + +func TestValidateZippedReader(t *testing.T) { + // wrap an innocuous "" XML payload in a flate.Reader : + zipped := flateIt(t, ``) + + // Validate should not trigger an error on that Reader : + err := Validate(zipped) + assert.NoError(t, err, "Should not error on a valid XML document") + + // an invalid document should still error : + zipped = flateIt(t, ``) + + err = Validate(zipped) + assert.Error(t, err, "Should error on an invalid XML document") +} diff --git a/validator.go b/validator.go index 6931b93..523659d 100644 --- a/validator.go +++ b/validator.go @@ -121,9 +121,20 @@ type byteReader struct { } func (r *byteReader) ReadByte() (byte, error) { - p := make([]byte, 1) - _, err := r.r.Read(p) - return p[0], err + var p [1]byte + n, err := r.r.Read(p[:]) + + // The doc for the io.ByteReader interface states: + // If ReadByte returns an error, no input byte was consumed, and the returned byte value is undefined. + // So if a byte is actually extracted from the reader, and we want to return it, we mustn't return the error. + if n > 0 { + // this byteReader is only used in the context of the Validate() function, + // we deliberately choose to completely ignore the error in this case. + // return the byte extracted from the reader + return p[0], nil + } + + return 0, err } func (r *byteReader) Read(p []byte) (int, error) {