fix LimitedStream.read method to work with raw IO streams (#2559)
davidism authored Feb 6, 2023
2 parents d554cb7 + 591b115 commit 64f0eb4
Showing 3 changed files with 110 additions and 23 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
@@ -17,6 +17,8 @@ Unreleased
   client. :issue:`2549`
 - Fix handling of header extended parameters such that they are no longer quoted.
   :issue:`2529`
+- ``LimitedStream.read`` works correctly when wrapping a stream that may not return
+  the requested size in one ``read`` call. :issue:`2558`
 
 
 Version 2.2.2
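As a quick illustration of the changelog entry above (not part of this commit, assuming werkzeug 2.2.3+): ShortReadStream below is a hypothetical stand-in for a raw IO object, such as a socket file, whose read() returns fewer bytes per call than requested.

import io

from werkzeug.wsgi import LimitedStream


class ShortReadStream(io.RawIOBase):
    # Illustrative raw stream: fills at most two bytes per readinto() call,
    # so read(n) can return fewer bytes than requested.
    def __init__(self, data: bytes) -> None:
        self._buf = io.BytesIO(data)

    def readable(self) -> bool:
        return True

    def readinto(self, b) -> int:
        chunk = self._buf.read(min(len(b), 2))
        b[: len(chunk)] = chunk
        return len(chunk)


stream = LimitedStream(ShortReadStream(b"hello world"), 5)

# On werkzeug 2.2.2 this raised ClientDisconnected after the first short read;
# with this change it keeps reading chunks until the 5-byte limit is reached.
assert stream.read() == b"hello"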
70 changes: 55 additions & 15 deletions src/werkzeug/wsgi.py
@@ -928,37 +928,77 @@ def on_disconnect(self) -> bytes:

         raise ClientDisconnected()
 
-    def exhaust(self, chunk_size: int = 1024 * 64) -> None:
-        """Exhaust the stream. This consumes all the data left until the
-        limit is reached.
+    def _exhaust_chunks(self, chunk_size: int = 1024 * 64) -> t.Iterator[bytes]:
+        """Exhaust the stream by reading until the limit is reached or the client
+        disconnects, yielding each chunk.
 
-        :param chunk_size: the size for a chunk. It will read the chunk
-                           until the stream is exhausted and throw away
-                           the results.
+        :param chunk_size: How many bytes to read at a time.
+
+        :meta private:
+
+        .. versionadded:: 2.2.3
         """
         to_read = self.limit - self._pos
-        chunk = chunk_size
 
         while to_read > 0:
-            chunk = min(to_read, chunk)
-            self.read(chunk)
-            to_read -= chunk
+            chunk = self.read(min(to_read, chunk_size))
+            yield chunk
+            to_read -= len(chunk)
+
+    def exhaust(self, chunk_size: int = 1024 * 64) -> None:
+        """Exhaust the stream by reading until the limit is reached or the client
+        disconnects, discarding the data.
+
+        :param chunk_size: How many bytes to read at a time.
+
+        .. versionchanged:: 2.2.3
+            Handle case where wrapped stream returns fewer bytes than requested.
+        """
+        for _ in self._exhaust_chunks(chunk_size):
+            pass
 
     def read(self, size: t.Optional[int] = None) -> bytes:
-        """Read `size` bytes or if size is not provided everything is read.
+        """Read up to ``size`` bytes from the underlying stream. If size is not
+        provided, read until the limit.
+
+        If the limit is reached, :meth:`on_exhausted` is called, which returns empty
+        bytes.
 
-        :param size: the number of bytes read.
+        If no bytes are read and the limit is not reached, or if an error occurs during
+        the read, :meth:`on_disconnect` is called, which raises
+        :exc:`.ClientDisconnected`.
+
+        :param size: The number of bytes to read. ``None``, default, reads until the
+            limit is reached.
+
+        .. versionchanged:: 2.2.3
+            Handle case where wrapped stream returns fewer bytes than requested.
         """
         if self._pos >= self.limit:
             return self.on_exhausted()
-        if size is None or size == -1:  # -1 is for consistence with file
-            size = self.limit
+
+        if size is None or size == -1:  # -1 is for consistency with file
+            # Keep reading from the wrapped stream until the limit is reached. Can't
+            # rely on stream.read(size) because it's not guaranteed to return size.
+            buf = bytearray()
+
+            for chunk in self._exhaust_chunks():
+                buf.extend(chunk)
+
+            return bytes(buf)
+
         to_read = min(self.limit - self._pos, size)
+
         try:
             read = self._read(to_read)
         except (OSError, ValueError):
             return self.on_disconnect()
-        if to_read and len(read) != to_read:
+
+        if to_read and not len(read):
+            # If no data was read, treat it as a disconnect. As long as some data was
+            # read, a subsequent call can still return more before reaching the limit.
             return self.on_disconnect()
+
         self._pos += len(read)
         return read

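A small sketch of the read() semantics spelled out in the updated docstring above (on_exhausted returning empty bytes at the limit, on_disconnect raising ClientDisconnected otherwise), assuming werkzeug 2.2.3+ and only stdlib streams; it mirrors the updated disconnection tests below.

import io

import pytest

from werkzeug.exceptions import ClientDisconnected
from werkzeug.wsgi import LimitedStream

# Limit reached -> on_exhausted() -> empty bytes, no exception.
stream = LimitedStream(io.BytesIO(b"abc"), 3)
assert stream.read() == b"abc"
assert stream.read() == b""

# Zero bytes returned before the limit -> on_disconnect() -> ClientDisconnected.
stream = LimitedStream(io.BytesIO(), 3)
with pytest.raises(ClientDisconnected):
    stream.read()

# An OSError or ValueError from the wrapped stream is also treated as a disconnect.
closed = io.BytesIO(b"xyz")
closed.close()
stream = LimitedStream(closed, 3)
with pytest.raises(ClientDisconnected):
    stream.read()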
61 changes: 53 additions & 8 deletions tests/test_wsgi.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
 import io
 import json
 import os
+import typing as t
 
 import pytest
 
@@ -165,21 +168,63 @@ def test_limited_stream_json_load():


 def test_limited_stream_disconnection():
-    io_ = io.BytesIO(b"A bit of content")
-
-    # disconnect detection on out of bytes
-    stream = wsgi.LimitedStream(io_, 255)
+    # disconnect because stream returns zero bytes
+    stream = wsgi.LimitedStream(io.BytesIO(), 255)
     with pytest.raises(ClientDisconnected):
         stream.read()
 
-    # disconnect detection because file close
-    io_ = io.BytesIO(b"x" * 255)
-    io_.close()
-    stream = wsgi.LimitedStream(io_, 255)
+    # disconnect because stream is closed
+    data = io.BytesIO(b"x" * 255)
+    data.close()
+    stream = wsgi.LimitedStream(data, 255)
+
     with pytest.raises(ClientDisconnected):
         stream.read()
 
 
+def test_limited_stream_read_with_raw_io():
+    class OneByteStream(t.BinaryIO):
+        def __init__(self, buf: bytes) -> None:
+            self.buf = buf
+            self.pos = 0
+
+        def read(self, size: int | None = None) -> bytes:
+            """Return one byte at a time regardless of requested size."""
+
+            if size is None or size == -1:
+                raise ValueError("expected read to be called with specific limit")
+
+            if size == 0 or len(self.buf) < self.pos:
+                return b""
+
+            b = self.buf[self.pos : self.pos + 1]
+            self.pos += 1
+            return b
+
+    stream = wsgi.LimitedStream(OneByteStream(b"foo"), 4)
+    assert stream.read(5) == b"f"
+    assert stream.read(5) == b"o"
+    assert stream.read(5) == b"o"
+
+    # The stream has fewer bytes (3) than the limit (4), therefore the read returns 0
+    # bytes before the limit is reached.
+    with pytest.raises(ClientDisconnected):
+        stream.read(5)
+
+    stream = wsgi.LimitedStream(OneByteStream(b"foo123"), 3)
+    assert stream.read(5) == b"f"
+    assert stream.read(5) == b"o"
+    assert stream.read(5) == b"o"
+    # The limit was reached, therefore the wrapper is exhausted, not disconnected.
+    assert stream.read(5) == b""
+
+    stream = wsgi.LimitedStream(OneByteStream(b"foo"), 3)
+    assert stream.read() == b"foo"
+
+    stream = wsgi.LimitedStream(OneByteStream(b"foo"), 2)
+    assert stream.read() == b"fo"
+
+
 def test_get_host_fallback():
     assert (
         wsgi.get_host(
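For context (not part of this commit): LimitedStream is typically used to cap reads from a WSGI input stream, which is exactly the kind of raw stream that may return short reads. A minimal, illustrative WSGI app using the standard environ keys might look like this; the application itself is a sketch, not taken from werkzeug.

from werkzeug.wsgi import LimitedStream


def app(environ, start_response):
    # Wrap the raw input stream so reads never go past CONTENT_LENGTH.
    length = int(environ.get("CONTENT_LENGTH") or 0)
    body = LimitedStream(environ["wsgi.input"], length).read()
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"received %d bytes" % len(body)]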
