Skip to content

Commit

Permalink
Fix WSGI middleware not to explode quadratically in the case of a lar…
Browse files Browse the repository at this point in the history
…ger body (#1329)

* Fix WSGI middleware not to explode quadratically in the case of a larger body.

* Clean up more_body and body stream handling, get rid of while-else.

* Update body type in `build_environ`'s tests (although they pass regardless...)

* Address typing issues.
  • Loading branch information
vytas7 authored Feb 16, 2022
1 parent 399f90a commit a694ee4
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
18 changes: 16 additions & 2 deletions tests/middleware/test_wsgi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import sys
from typing import List
from typing import AsyncGenerator, List

import httpx
import pytest
Expand Down Expand Up @@ -67,6 +68,19 @@ async def test_wsgi_post() -> None:
assert response.text == '{"example": 123}'


@pytest.mark.asyncio
async def test_wsgi_put_more_body() -> None:
async def generate_body() -> AsyncGenerator[bytes, None]:
for _ in range(1024):
yield b"123456789abcdef\n" * 64

app = WSGIMiddleware(echo_body)
async with httpx.AsyncClient(app=app, base_url="http://testserver") as client:
response = await client.put("/", content=generate_body())
assert response.status_code == 200
assert response.text == "123456789abcdef\n" * 64 * 1024


@pytest.mark.asyncio
async def test_wsgi_exception() -> None:
# Note that we're testing the WSGI app directly here.
Expand Down Expand Up @@ -120,6 +134,6 @@ def test_build_environ_encoding() -> None:
"body": b"",
"more_body": False,
}
environ = build_environ(scope, message, b"")
environ = build_environ(scope, message, io.BytesIO(b""))
assert environ["PATH_INFO"] == "/文".encode("utf8").decode("latin-1")
assert environ["HTTP_KEY"] == "value1,value2"
21 changes: 14 additions & 7 deletions uvicorn/middleware/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from uvicorn._types import Environ, ExcInfo, StartResponse, WSGIApp


def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: bytes) -> Environ:
def build_environ(
scope: HTTPScope, message: ASGIReceiveEvent, body: io.BytesIO
) -> Environ:
"""
Builds a scope and request message into a WSGI environ object.
"""
Expand All @@ -31,7 +33,7 @@ def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: bytes) -> E
"SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"],
"wsgi.version": (1, 0),
"wsgi.url_scheme": scope.get("scheme", "http"),
"wsgi.input": io.BytesIO(body),
"wsgi.input": body,
"wsgi.errors": sys.stdout,
"wsgi.multithread": True,
"wsgi.multiprocess": True,
Expand Down Expand Up @@ -105,12 +107,17 @@ async def __call__(
self, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
message: HTTPRequestEvent = await receive() # type: ignore[assignment]
body = message.get("body", b"")
body = io.BytesIO(message.get("body", b""))
more_body = message.get("more_body", False)
while more_body:
body_message: HTTPRequestEvent = await receive() # type: ignore[assignment]
body += body_message.get("body", b"")
more_body = body_message.get("more_body", False)
if more_body:
body.seek(0, io.SEEK_END)
while more_body:
body_message: HTTPRequestEvent = (
await receive() # type: ignore[assignment]
)
body.write(body_message.get("body", b""))
more_body = body_message.get("more_body", False)
body.seek(0)
environ = build_environ(self.scope, message, body)
self.loop = asyncio.get_event_loop()
wsgi = self.loop.run_in_executor(
Expand Down

0 comments on commit a694ee4

Please # to comment.