diff --git a/CHANGES.rst b/CHANGES.rst index 86d3a63ee..0d402a3ea 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -17,6 +17,8 @@ Unreleased client. :issue:`2549` - Fix handling of header extended parameters such that they are no longer quoted. :issue:`2529` +- ``LimitedStream.read`` works correctly when wrapping a stream that may not return + the requested size in one ``read`` call. :issue:`2558` Version 2.2.2 diff --git a/src/werkzeug/wsgi.py b/src/werkzeug/wsgi.py index 54dd60e62..d74430d8b 100644 --- a/src/werkzeug/wsgi.py +++ b/src/werkzeug/wsgi.py @@ -928,37 +928,77 @@ def on_disconnect(self) -> bytes: raise ClientDisconnected() - def exhaust(self, chunk_size: int = 1024 * 64) -> None: - """Exhaust the stream. This consumes all the data left until the - limit is reached. + def _exhaust_chunks(self, chunk_size: int = 1024 * 64) -> t.Iterator[bytes]: + """Exhaust the stream by reading until the limit is reached or the client + disconnects, yielding each chunk. + + :param chunk_size: How many bytes to read at a time. + + :meta private: - :param chunk_size: the size for a chunk. It will read the chunk - until the stream is exhausted and throw away - the results. + .. versionadded:: 2.2.3 """ to_read = self.limit - self._pos - chunk = chunk_size + while to_read > 0: - chunk = min(to_read, chunk) - self.read(chunk) - to_read -= chunk + chunk = self.read(min(to_read, chunk_size)) + yield chunk + to_read -= len(chunk) + + def exhaust(self, chunk_size: int = 1024 * 64) -> None: + """Exhaust the stream by reading until the limit is reached or the client + disconnects, discarding the data. + + :param chunk_size: How many bytes to read at a time. + + .. versionchanged:: 2.2.3 + Handle case where wrapped stream returns fewer bytes than requested. + """ + for _ in self._exhaust_chunks(chunk_size): + pass def read(self, size: t.Optional[int] = None) -> bytes: - """Read `size` bytes or if size is not provided everything is read. + """Read up to ``size`` bytes from the underlying stream. If size is not + provided, read until the limit. + + If the limit is reached, :meth:`on_exhausted` is called, which returns empty + bytes. - :param size: the number of bytes read. + If no bytes are read and the limit is not reached, or if an error occurs during + the read, :meth:`on_disconnect` is called, which raises + :exc:`.ClientDisconnected`. + + :param size: The number of bytes to read. ``None``, default, reads until the + limit is reached. + + .. versionchanged:: 2.2.3 + Handle case where wrapped stream returns fewer bytes than requested. """ if self._pos >= self.limit: return self.on_exhausted() - if size is None or size == -1: # -1 is for consistence with file - size = self.limit + + if size is None or size == -1: # -1 is for consistency with file + # Keep reading from the wrapped stream until the limit is reached. Can't + # rely on stream.read(size) because it's not guaranteed to return size. + buf = bytearray() + + for chunk in self._exhaust_chunks(): + buf.extend(chunk) + + return bytes(buf) + to_read = min(self.limit - self._pos, size) + try: read = self._read(to_read) except (OSError, ValueError): return self.on_disconnect() - if to_read and len(read) != to_read: + + if to_read and not len(read): + # If no data was read, treat it as a disconnect. As long as some data was + # read, a subsequent call can still return more before reaching the limit. return self.on_disconnect() + self._pos += len(read) return read diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index b0f71bcdf..cdc151d1f 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import io import json import os +import typing as t import pytest @@ -165,21 +168,63 @@ def test_limited_stream_json_load(): def test_limited_stream_disconnection(): - io_ = io.BytesIO(b"A bit of content") - - # disconnect detection on out of bytes - stream = wsgi.LimitedStream(io_, 255) + # disconnect because stream returns zero bytes + stream = wsgi.LimitedStream(io.BytesIO(), 255) with pytest.raises(ClientDisconnected): stream.read() - # disconnect detection because file close - io_ = io.BytesIO(b"x" * 255) - io_.close() - stream = wsgi.LimitedStream(io_, 255) + # disconnect because stream is closed + data = io.BytesIO(b"x" * 255) + data.close() + stream = wsgi.LimitedStream(data, 255) + with pytest.raises(ClientDisconnected): stream.read() +def test_limited_stream_read_with_raw_io(): + class OneByteStream(t.BinaryIO): + def __init__(self, buf: bytes) -> None: + self.buf = buf + self.pos = 0 + + def read(self, size: int | None = None) -> bytes: + """Return one byte at a time regardless of requested size.""" + + if size is None or size == -1: + raise ValueError("expected read to be called with specific limit") + + if size == 0 or len(self.buf) < self.pos: + return b"" + + b = self.buf[self.pos : self.pos + 1] + self.pos += 1 + return b + + stream = wsgi.LimitedStream(OneByteStream(b"foo"), 4) + assert stream.read(5) == b"f" + assert stream.read(5) == b"o" + assert stream.read(5) == b"o" + + # The stream has fewer bytes (3) than the limit (4), therefore the read returns 0 + # bytes before the limit is reached. + with pytest.raises(ClientDisconnected): + stream.read(5) + + stream = wsgi.LimitedStream(OneByteStream(b"foo123"), 3) + assert stream.read(5) == b"f" + assert stream.read(5) == b"o" + assert stream.read(5) == b"o" + # The limit was reached, therefore the wrapper is exhausted, not disconnected. + assert stream.read(5) == b"" + + stream = wsgi.LimitedStream(OneByteStream(b"foo"), 3) + assert stream.read() == b"foo" + + stream = wsgi.LimitedStream(OneByteStream(b"foo"), 2) + assert stream.read() == b"fo" + + def test_get_host_fallback(): assert ( wsgi.get_host(