From 53c1473b5ff7502816a9a339ffc90731bb0c2138 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Wed, 20 Nov 2024 16:09:41 +0100 Subject: [PATCH] Merge commit from fork Co-authored-by: Jacob Coffee Co-authored-by: guacs <126393040+guacs@users.noreply.github.com> --- docs/usage/requests.rst | 49 +++++ litestar/app.py | 5 + litestar/config/app.py | 3 + litestar/connection/request.py | 66 +++++- litestar/controller.py | 10 + litestar/exceptions/http_exceptions.py | 6 + litestar/handlers/http_handlers/base.py | 26 +++ litestar/handlers/http_handlers/decorators.py | 12 ++ litestar/router.py | 5 + tests/conftest.py | 1 + .../test_connection_caching.py | 12 +- tests/unit/test_connection/test_request.py | 193 ++++++++++++------ .../test_http_handlers/test_resolution.py | 66 ++++++ 13 files changed, 387 insertions(+), 67 deletions(-) create mode 100644 tests/unit/test_handlers/test_http_handlers/test_resolution.py diff --git a/docs/usage/requests.rst b/docs/usage/requests.rst index ac749f7d01..3dfef97435 100644 --- a/docs/usage/requests.rst +++ b/docs/usage/requests.rst @@ -160,3 +160,52 @@ The example below illustrates how to implement custom request class for the whol class on multiple layers, the layer closest to the route handler will take precedence. You can read more about this in the :ref:`usage/applications:layered architecture` section + + +Limits +------- + +Body size +^^^^^^^^^^ + +A limit for the allowed request body size can be set on all layers via the +``request_max_body_size`` parameter and defaults to 10MB. If a request body exceeds this +limit, a ``413 - Request Entity Too Large`` +response will be returned. This limit applies to all methods of consuming the request +body, including requesting it via the ``body`` parameter in a route handler and +consuming it through a manually constructed :class:`~litestar.connection.Request` +instance, e.g. in a middleware. + +To disable this limit for a specific handler / router / controller, it can be set to +:obj:`None`. + +.. danger:: + Setting ``request_max_body_size=None`` is strongly discouraged as it exposes the + application to a denial of service (DoS) attack by sending arbitrarily large + request bodies to the affected endpoint. Because Litestar has to read the whole body + to perform certain actions, such as parsing JSON, it will fill up all the available + memory / swap until the application / server crashes, should no outside limits be + imposed. + + This is generally only recommended in environments where the application is running + behind a reverse proxy such as NGINX, where a size limit is already set. + + +.. danger:: + Since ``request_max_body_size`` is handled on a per-request basis, it won't affect + middlewares or ASGI handlers when they try to access the request body via the raw + ASGI events. To avoid this, middlewares and ASGI handlers should construct a + :class:`~litestar.connection.Request` instance and use the regular + :meth:`~litestar.connection.Request.stream` / + :meth:`~litestar.connection.Request.body` or content-appropriate method to consume + the request body in a safe manner. + + +.. tip:: + For requests that define a ``Content-Length`` header, Litestar will not attempt to + read the request body should the header value exceed the ``request_max_body_size``. + + If the header value is within the allowed bounds, Litestar will verify during the + streaming of the request body that it does not exceed the size specified in the + header. Should the request exceed this size, it will abort the request with a + ``400 - Bad Request``. diff --git a/litestar/app.py b/litestar/app.py index b09e9b6abd..62983388ec 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -202,6 +202,7 @@ def __init__( path: str | None = None, plugins: Sequence[PluginProtocol] | None = None, request_class: type[Request] | None = None, + request_max_body_size: int | None = 10_000_000, response_cache_config: ResponseCacheConfig | None = None, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, @@ -286,6 +287,8 @@ def __init__( pdb_on_exception: Drop into the PDB when an exception occurs. plugins: Sequence of plugins. request_class: An optional subclass of :class:`Request <.connection.Request>` to use for http connections. + request_max_body_size: Maximum allowed size of the request body in bytes. If this size is exceeded, a + '413 - Request Entity Too Large' error response is returned. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as the app's default response. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>`. @@ -361,6 +364,7 @@ def __init__( pdb_on_exception=pdb_on_exception, plugins=self._get_default_plugins(list(plugins or [])), request_class=request_class, + request_max_body_size=request_max_body_size, response_cache_config=response_cache_config or ResponseCacheConfig(), response_class=response_class, response_cookies=response_cookies or [], @@ -464,6 +468,7 @@ def __init__( parameters=config.parameters, path=config.path, request_class=self.request_class, + request_max_body_size=request_max_body_size, response_class=config.response_class, response_cookies=config.response_cookies, response_headers=config.response_headers, diff --git a/litestar/config/app.py b/litestar/config/app.py index aef812ecca..a314b9e9e3 100644 --- a/litestar/config/app.py +++ b/litestar/config/app.py @@ -163,6 +163,9 @@ class AppConfig: """List of :class:`SerializationPluginProtocol <.plugins.SerializationPluginProtocol>`.""" request_class: type[Request] | None = field(default=None) """An optional subclass of :class:`Request <.connection.Request>` to use for http connections.""" + request_max_body_size: int | None | EmptyType = Empty + """Maximum allowed size of the request body in bytes. If this size is exceeded, a '413 - Request Entity Too Large' + error response is returned.""" response_class: type[Response] | None = field(default=None) """A custom subclass of :class:`Response <.response.Response>` to be used as the app's default response.""" response_cookies: ResponseCookies = field(default_factory=list) diff --git a/litestar/connection/request.py b/litestar/connection/request.py index 23c60f0b3c..e76054b042 100644 --- a/litestar/connection/request.py +++ b/litestar/connection/request.py @@ -1,7 +1,8 @@ from __future__ import annotations +import math import warnings -from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic +from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, cast from litestar._multipart import parse_content_header, parse_multipart_form from litestar._parsers import parse_url_encoded_form_data @@ -17,12 +18,14 @@ from litestar.datastructures.multi_dicts import FormMultiDict from litestar.enums import ASGIExtension, RequestEncodingType from litestar.exceptions import ( + ClientException, InternalServerException, LitestarException, LitestarWarning, ) +from litestar.exceptions.http_exceptions import RequestEntityTooLarge from litestar.serialization import decode_json, decode_msgpack -from litestar.types import Empty +from litestar.types import Empty, HTTPReceiveMessage __all__ = ("Request",) @@ -52,6 +55,7 @@ class Request(Generic[UserT, AuthT, StateT], ASGIConnection["HTTPRouteHandler", "_msgpack", "_content_type", "_accept", + "_content_length", "is_connected", "supports_push_promise", ) @@ -79,6 +83,7 @@ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = self._msgpack: Any = Empty self._content_type: tuple[str, dict[str, str]] | EmptyType = Empty self._accept: Accept | EmptyType = Empty + self._content_length: int | None | EmptyType = Empty self.supports_push_promise = ASGIExtension.SERVER_PUSH in self._server_extensions @property @@ -152,6 +157,21 @@ async def msgpack(self) -> Any: ) return self._msgpack + @property + def content_length(self) -> int | None: + cached_content_length = self._content_length + if cached_content_length is not Empty: + return cached_content_length + + content_length_header = self.headers.get("content-length") + try: + content_length = self._content_length = ( + int(content_length_header) if content_length_header is not None else None + ) + except ValueError: + raise ClientException(f"Invalid content-length: {content_length_header!r}") from None + return content_length + async def stream(self) -> AsyncGenerator[bytes, None]: """Return an async generator that streams chunks of bytes. @@ -164,10 +184,46 @@ async def stream(self) -> AsyncGenerator[bytes, None]: if self._body is Empty: if not self.is_connected: raise InternalServerException("stream consumed") - while event := await self.receive(): + + announced_content_length = self.content_length + # setting this to 'math.inf' as a micro-optimisation; Comparing against a + # float is slightly faster than checking if a value is 'None' and then + # comparing it to an int. since we expect a limit to be set most of the + # time, this is a bit more efficient + max_content_length = self.route_handler.resolve_request_max_body_size() or math.inf + + # if the 'content-length' header is set, and exceeds the limit, we can bail + # out early before reading anything + if announced_content_length is not None and announced_content_length > max_content_length: + raise RequestEntityTooLarge + + total_bytes_streamed: int = 0 + while event := cast("HTTPReceiveMessage", await self.receive()): if event["type"] == "http.request": - if event["body"]: - yield event["body"] + body = event["body"] + if body: + total_bytes_streamed += len(body) + + # if a 'content-length' header was set, check if we have + # received more bytes than specified. in most cases this should + # be caught before it hits the application layer and an ASGI + # server (e.g. uvicorn) will not allow this, but since it's not + # forbidden according to the HTTP or ASGI spec, we err on the + # side of caution and still perform this check. + # + # uvicorn documented behaviour for this case: + # https://github.com/encode/uvicorn/blob/fe3910083e3990695bc19c2ef671dd447262ae18/docs/server-behavior.md?plain=1#L11 + if announced_content_length: + if total_bytes_streamed > announced_content_length: + raise ClientException("Malformed request") + + # we don't have a 'content-length' header, likely a chunked + # transfer. we don't really care and simply check if we have + # received more bytes than allowed + elif total_bytes_streamed > max_content_length: + raise RequestEntityTooLarge + + yield body if not event.get("more_body", False): break diff --git a/litestar/controller.py b/litestar/controller.py index 3893acdf98..7786e023f9 100644 --- a/litestar/controller.py +++ b/litestar/controller.py @@ -64,6 +64,7 @@ class Controller: "parameters", "path", "request_class", + "request_max_body_size", "response_class", "response_cookies", "response_headers", @@ -136,6 +137,11 @@ class Controller: """A custom subclass of :class:`Request <.connection.Request>` to be used as the default request for all route handlers under the controller. """ + request_max_body_size: int | None | EmptyType + """ + Maximum allowed size of the request body in bytes. If this size is exceeded, a '413 - Request Entity Too Large' + error response is returned.""" + response_class: type[Response] | None """A custom subclass of :class:`Response <.response.Response>` to be used as the default response for all route handlers under the controller. @@ -191,6 +197,9 @@ def __init__(self, owner: Router) -> None: if not hasattr(self, "include_in_schema"): self.include_in_schema = Empty + if not hasattr(self, "request_max_body_size"): + self.request_max_body_size = Empty + self.signature_namespace = add_types_to_signature_namespace( getattr(self, "signature_types", []), getattr(self, "signature_namespace", {}) ) @@ -235,6 +244,7 @@ def as_router(self) -> Router: type_encoders=self.type_encoders, type_decoders=self.type_decoders, websocket_class=self.websocket_class, + request_max_body_size=self.request_max_body_size, ) router.owner = self.owner return router diff --git a/litestar/exceptions/http_exceptions.py b/litestar/exceptions/http_exceptions.py index bd384c363b..f3a34174eb 100644 --- a/litestar/exceptions/http_exceptions.py +++ b/litestar/exceptions/http_exceptions.py @@ -10,6 +10,7 @@ HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_405_METHOD_NOT_ALLOWED, + HTTP_413_REQUEST_ENTITY_TOO_LARGE, HTTP_429_TOO_MANY_REQUESTS, HTTP_500_INTERNAL_SERVER_ERROR, HTTP_503_SERVICE_UNAVAILABLE, @@ -119,6 +120,11 @@ class MethodNotAllowedException(ClientException): status_code = HTTP_405_METHOD_NOT_ALLOWED +class RequestEntityTooLarge(ClientException): + status_code = HTTP_413_REQUEST_ENTITY_TOO_LARGE + detail = "Request Entity Too Large" + + class TooManyRequestsException(ClientException): """Request limits have been exceeded.""" diff --git a/litestar/handlers/http_handlers/base.py b/litestar/handlers/http_handlers/base.py index 7796e40d63..8445d6c6ab 100644 --- a/litestar/handlers/http_handlers/base.py +++ b/litestar/handlers/http_handlers/base.py @@ -82,6 +82,7 @@ class HTTPRouteHandler(BaseRouteHandler): "_resolved_request_class", "_resolved_tags", "_resolved_security", + "_resolved_request_max_body_size", "after_request", "after_response", "background", @@ -113,6 +114,7 @@ class HTTPRouteHandler(BaseRouteHandler): "sync_to_thread", "tags", "template_name", + "request_max_body_size", ) has_sync_callable: bool @@ -139,6 +141,7 @@ def __init__( name: str | None = None, opt: Mapping[str, Any] | None = None, request_class: type[Request] | None = None, + request_max_body_size: int | None | EmptyType = Empty, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, @@ -204,6 +207,8 @@ def __init__( :class:`ASGI Scope <.types.Scope>`. request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's default request. + request_max_body_size: Maximum allowed size of the request body in bytes. If this size is exceeded, + a '413 - Request Entity Too Large' error response is returned. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's default response. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -272,6 +277,7 @@ def __init__( self.response_class = response_class self.response_cookies: Sequence[Cookie] | None = narrow_response_cookies(response_cookies) self.response_headers: Sequence[ResponseHeader] | None = narrow_response_headers(response_headers) + self.request_max_body_size = request_max_body_size self.sync_to_thread = sync_to_thread # OpenAPI related attributes @@ -297,6 +303,7 @@ def __init__( self._resolved_request_class: type[Request] | EmptyType = Empty self._resolved_security: list[SecurityRequirement] | EmptyType = Empty self._resolved_tags: list[str] | EmptyType = Empty + self._resolved_request_max_body_size: int | EmptyType | None = Empty def __call__(self, fn: AnyCallable) -> HTTPRouteHandler: """Replace a function with itself.""" @@ -473,6 +480,25 @@ def resolve_tags(self) -> list[str]: return self._resolved_tags + def resolve_request_max_body_size(self) -> int | None: + if (resolved_limits := self._resolved_request_max_body_size) is not Empty: + return resolved_limits + + max_body_size = self._resolved_request_max_body_size = next( # pyright: ignore + ( + max_body_size + for layer in reversed(self.ownership_layers) + if (max_body_size := layer.request_max_body_size) is not Empty + ), + Empty, + ) + if max_body_size is Empty: + raise ImproperlyConfiguredException( + "'request_max_body_size' set to 'Empty' on all layers. To omit a limit, " + "set 'request_max_body_size=None'" + ) + return max_body_size + def get_response_handler(self, is_response_type_data: bool = False) -> Callable[[Any], Awaitable[ASGIApp]]: """Resolve the response_handler function for the route handler. diff --git a/litestar/handlers/http_handlers/decorators.py b/litestar/handlers/http_handlers/decorators.py index 593a1a7d19..fe0b0cd56e 100644 --- a/litestar/handlers/http_handlers/decorators.py +++ b/litestar/handlers/http_handlers/decorators.py @@ -628,6 +628,7 @@ def __init__( name: str | None = None, opt: Mapping[str, Any] | None = None, request_class: type[Request] | None = None, + request_max_body_size: int | None | EmptyType = Empty, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, @@ -692,6 +693,8 @@ def __init__( wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's default request. + request_max_body_size: Maximum allowed size of the request body in bytes. If this size is exceeded, + a '413 - Request Entity Too Large' error response is returned. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's default response. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -755,6 +758,7 @@ def __init__( path=path, raises=raises, request_class=request_class, + request_max_body_size=request_max_body_size, response_class=response_class, response_cookies=response_cookies, response_description=response_description, @@ -803,6 +807,7 @@ def __init__( name: str | None = None, opt: Mapping[str, Any] | None = None, request_class: type[Request] | None = None, + request_max_body_size: int | None | EmptyType = Empty, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, @@ -867,6 +872,8 @@ def __init__( wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's default request. + request_max_body_size: Maximum allowed size of the request body in bytes. If this size is exceeded, + a '413 - Request Entity Too Large' error response is returned. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's default response. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -930,6 +937,7 @@ def __init__( path=path, raises=raises, request_class=request_class, + request_max_body_size=request_max_body_size, response_class=response_class, response_cookies=response_cookies, response_description=response_description, @@ -978,6 +986,7 @@ def __init__( name: str | None = None, opt: Mapping[str, Any] | None = None, request_class: type[Request] | None = None, + request_max_body_size: int | None | EmptyType = Empty, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, @@ -1042,6 +1051,8 @@ def __init__( wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's default request. + request_max_body_size: Maximum allowed size of the request body in bytes. If this size is exceeded, + a '413 - Request Entity Too Large' error response is returned. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's default response. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -1105,6 +1116,7 @@ def __init__( path=path, raises=raises, request_class=request_class, + request_max_body_size=request_max_body_size, response_class=response_class, response_cookies=response_cookies, response_description=response_description, diff --git a/litestar/router.py b/litestar/router.py index 88ac0fd567..6b9ca1c953 100644 --- a/litestar/router.py +++ b/litestar/router.py @@ -68,6 +68,7 @@ class Router: "path", "registered_route_handler_ids", "request_class", + "request_max_body_size", "response_class", "response_cookies", "response_headers", @@ -111,6 +112,7 @@ def __init__( type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, websocket_class: type[WebSocket] | None = None, + request_max_body_size: int | None | EmptyType = Empty, ) -> None: """Initialize a ``Router``. @@ -143,6 +145,8 @@ def __init__( with the router instance. request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as the default for all route handlers, controllers and other routers associated with the router instance. + request_max_body_size: Maximum allowed size of the request body in bytes. If this size is exceeded, + a '413 - Request Entity Too Large" error response is returned. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as the default for all route handlers, controllers and other routers associated with the router instance. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -197,6 +201,7 @@ def __init__( self.type_encoders = dict(type_encoders) if type_encoders is not None else None self.type_decoders = list(type_decoders) if type_decoders is not None else None self.websocket_class = websocket_class + self.request_max_body_size = request_max_body_size for route_handler in route_handlers or []: self.register(value=route_handler) diff --git a/tests/conftest.py b/tests/conftest.py index ca3032f680..e4ae950d22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -216,6 +216,7 @@ def inner( "route_handler": route_handler, "user": user, "session": session, + "headers": [], **kwargs, } return cast("Scope", scope) diff --git a/tests/unit/test_connection/test_connection_caching.py b/tests/unit/test_connection/test_connection_caching.py index 43c2fe9865..acbf46e706 100644 --- a/tests/unit/test_connection/test_connection_caching.py +++ b/tests/unit/test_connection/test_connection_caching.py @@ -5,7 +5,7 @@ import pytest -from litestar import Request +from litestar import Request, post from litestar.testing import RequestFactory from litestar.types import Empty, HTTPReceiveMessage, Scope from litestar.utils.scope.state import ScopeState @@ -17,11 +17,15 @@ async def test_multiple_request_object_data_caching(create_scope: Callable[..., https://github.com/litestar-org/litestar/issues/2727 """ + @post("/", request_max_body_size=None) + async def handler() -> None: + pass + async def test_receive() -> HTTPReceiveMessage: mock() return {"type": "http.request", "body": b"abc", "more_body": False} - scope = create_scope() + scope = create_scope(route_handler=handler) request_1 = Request[Any, Any, Any](scope, test_receive) request_2 = Request[Any, Any, Any](scope, test_receive) assert (await request_1.body()) == b"abc" @@ -121,6 +125,8 @@ def check_get_mock() -> None: get_mock.assert_has_calls([call(state_key), call("headers")]) elif state_key == "form": get_mock.assert_has_calls([call(state_key), call("content_type")]) + elif state_key == "body": + get_mock.assert_has_calls([call(state_key), call("headers")]) else: get_mock.assert_called_once_with(state_key) @@ -136,6 +142,8 @@ def check_set_mock() -> None: set_mock.assert_has_calls([call("content_type", ANY), call(state_key, ANY)]) elif state_key in {"accept", "cookies", "content_type"}: set_mock.assert_has_calls([call("headers", ANY), call(state_key, ANY)]) + elif state_key == "body": + set_mock.assert_has_calls([call("headers", ANY), call(state_key, ANY)]) else: set_mock.assert_called_once_with(state_key, ANY) diff --git a/tests/unit/test_connection/test_request.py b/tests/unit/test_connection/test_request.py index ec532852d7..930df71f34 100644 --- a/tests/unit/test_connection/test_request.py +++ b/tests/unit/test_connection/test_request.py @@ -11,7 +11,7 @@ import pytest -from litestar import MediaType, Request, asgi, get, post +from litestar import MediaType, Request, get, post from litestar.connection.base import AuthT, StateT, UserT, empty_send from litestar.datastructures import Address, Cookie, State from litestar.exceptions import ( @@ -24,6 +24,7 @@ from litestar.response.base import ASGIResponse from litestar.serialization import encode_json, encode_msgpack from litestar.static_files.config import StaticFilesConfig +from litestar.status_codes import HTTP_400_BAD_REQUEST, HTTP_413_REQUEST_ENTITY_TOO_LARGE from litestar.testing import TestClient, create_test_client if TYPE_CHECKING: @@ -32,7 +33,7 @@ from litestar.types import ASGIApp, Receive, Scope, Send -@get("/", sync_to_thread=False) +@get("/", sync_to_thread=False, request_max_body_size=None) def _route_handler() -> None: pass @@ -230,56 +231,51 @@ def test_request_client( def test_request_body() -> None: - async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, State](scope, receive) + @post("/") + async def handler(request: Request) -> bytes: body = await request.body() - response = ASGIResponse(body=encode_json({"body": body.decode()})) - await response(scope, receive, send) - - client = TestClient(app) + return encode_json({"body": body.decode()}) - response = client.get("/") - assert response.json() == {"body": ""} + with create_test_client([handler]) as client: + response = client.post("/") + assert response.json() == {"body": ""} - response = client.post("/", json={"a": "123"}) - assert response.json() == {"body": '{"a": "123"}'} + response = client.post("/", json={"a": "123"}) + assert response.json() == {"body": '{"a": "123"}'} - response = client.post("/", content="abc") - assert response.json() == {"body": "abc"} + response = client.post("/", content="abc") + assert response.json() == {"body": "abc"} def test_request_stream() -> None: - async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, State](scope, receive) + @post("/") + async def handler(request: Request) -> bytes: body = b"" async for chunk in request.stream(): body += chunk - response = ASGIResponse(body=encode_json({"body": body.decode()})) - await response(scope, receive, send) + return encode_json({"body": body.decode()}) - client = TestClient(app) + with create_test_client([handler]) as client: + response = client.post("/") + assert response.json() == {"body": ""} - response = client.get("/") - assert response.json() == {"body": ""} - - response = client.post("/", json={"a": "123"}) - assert response.json() == {"body": '{"a": "123"}'} + response = client.post("/", json={"a": "123"}) + assert response.json() == {"body": '{"a": "123"}'} - response = client.post("/", content="abc") - assert response.json() == {"body": "abc"} + response = client.post("/", content="abc") + assert response.json() == {"body": "abc"} def test_request_form_urlencoded() -> None: - async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, State](scope, receive) + @post("/") + async def handler(request: Request) -> bytes: form = await request.form() - response = ASGIResponse(body=encode_json({"form": dict(form)})) - await response(scope, receive, send) - client = TestClient(app) + return encode_json({"form": dict(form)}) - response = client.post("/", data={"abc": "123 @"}) - assert response.json() == {"form": {"abc": "123 @"}} + with create_test_client([handler]) as client: + response = client.post("/", data={"abc": "123 @"}) + assert response.json() == {"form": {"abc": "123 @"}} def test_request_form_urlencoded_multi_keys() -> None: @@ -301,19 +297,17 @@ async def handler(request: Request) -> int: def test_request_body_then_stream() -> None: - async def app(scope: Any, receive: Receive, send: Send) -> None: - request = Request[Any, Any, State](scope, receive) + @post("/") + async def handler(request: Request) -> bytes: body = await request.body() chunks = b"" async for chunk in request.stream(): chunks += chunk - response = ASGIResponse(body=encode_json({"body": body.decode(), "stream": chunks.decode()})) - await response(scope, receive, send) - - client = TestClient(app) + return encode_json({"body": body.decode(), "stream": chunks.decode()}) - response = client.post("/", content="abc") - assert response.json() == {"body": "abc", "stream": "abc"} + with create_test_client([handler]) as client: + response = client.post("/", content="abc") + assert response.json() == {"body": "abc", "stream": "abc"} def test_request_stream_then_body() -> None: @@ -329,19 +323,27 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = ASGIResponse(body=encode_json({"body": body.decode(), "stream": chunks.decode()})) await response(scope, receive, send) - client = TestClient(app) + @post("/") + async def handler(request: Request) -> bytes: + chunks = b"" + async for chunk in request.stream(): + chunks += chunk + try: + body = await request.body() + except InternalServerException: + body = b"" + return encode_json({"body": body.decode(), "stream": chunks.decode()}) - response = client.post("/", content="abc") - assert response.json() == {"body": "", "stream": "abc"} + with create_test_client([handler]) as client: + response = client.post("/", content="abc") + assert response.json() == {"body": "", "stream": "abc"} def test_request_json() -> None: - @asgi("/") - async def handler(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, State](scope, receive) + @post("/") + async def handler(request: Request) -> bytes: data = await request.json() - response = ASGIResponse(body=encode_json({"json": data})) - await response(scope, receive, send) + return encode_json({"json": data}) with create_test_client(handler) as client: response = client.post("/", json={"a": "123"}) @@ -361,10 +363,11 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.text == "/he/llo, b'/he%2Fllo'" -def test_request_without_setting_receive() -> None: +def test_request_without_setting_receive(create_scope: Callable[..., Scope]) -> None: """If Request is instantiated without the 'receive' channel, then .body() is not available.""" async def app(scope: Scope, receive: Receive, send: Send) -> None: + scope.update(create_scope(route_handler=_route_handler)) # type: ignore[typeddict-item] request = Request[Any, Any, State](scope) try: data = await request.json() @@ -431,20 +434,19 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_chunked_encoding() -> None: - async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request[Any, Any, State](scope, receive) + @post("/") + async def handler(request: Request) -> bytes: body = await request.body() - response = ASGIResponse(body=encode_json({"body": body.decode()})) - await response(scope, receive, send) + return encode_json({"body": body.decode()}) - client = TestClient(app) + with create_test_client([handler]) as client: - def post_body() -> Generator[bytes, None, None]: - yield b"foo" - yield b"bar" + def post_body() -> Generator[bytes, None, None]: + yield b"foo" + yield b"bar" - response = client.post("/", content=post_body()) - assert response.json() == {"body": "foobar"} + response = client.post("/", content=post_body()) + assert response.json() == {"body": "foobar"} def test_request_send_push_promise() -> None: @@ -548,3 +550,74 @@ async def get_state(request: Request[Any, Any, State]) -> dict[str, str]: ) as client: response = client.get("/") assert response.json() == {"state": 2} + + +def test_request_body_exceeds_content_length() -> None: + @post("/") + def handler(body: bytes) -> None: + pass + + with create_test_client([handler]) as client: + response = client.post("/", headers={"content-length": "1"}, content=b"ab") + assert response.status_code == HTTP_400_BAD_REQUEST + assert response.json() == {"status_code": 400, "detail": "Malformed request"} + + +def test_request_body_exceeds_max_request_body_size() -> None: + @post("/one", request_max_body_size=1) + async def handler_one(request: Request) -> None: + await request.body() + + @post("/two", request_max_body_size=1) + async def handler_two(body: bytes) -> None: + pass + + with create_test_client([handler_one, handler_two]) as client: + response = client.post("/one", headers={"content-length": "2"}, content=b"ab") + assert response.status_code == HTTP_413_REQUEST_ENTITY_TOO_LARGE + + response = client.post("/two", headers={"content-length": "2"}, content=b"ab") + assert response.status_code == HTTP_413_REQUEST_ENTITY_TOO_LARGE + + +def test_request_body_exceeds_max_request_body_size_chunked() -> None: + @post("/one", request_max_body_size=1) + async def handler_one(request: Request) -> None: + assert request.headers["transfer-encoding"] == "chunked" + await request.body() + + @post("/two", request_max_body_size=1) + async def handler_two(body: bytes, request: Request) -> None: + assert request.headers["transfer-encoding"] == "chunked" + await request.body() + + def generator() -> Generator[bytes, None, None]: + yield b"1" + yield b"2" + + with create_test_client([handler_one, handler_two]) as client: + response = client.post("/one", content=generator()) + assert response.status_code == HTTP_413_REQUEST_ENTITY_TOO_LARGE + + response = client.post("/two", content=generator()) + assert response.status_code == HTTP_413_REQUEST_ENTITY_TOO_LARGE + + +def test_request_content_length() -> None: + @post("/") + def handler(request: Request) -> dict: + return {"content-length": request.content_length} + + with create_test_client([handler]) as client: + assert client.post("/", content=b"1").json() == {"content-length": 1} + + +def test_request_invalid_content_length() -> None: + @post("/") + def handler(request: Request) -> dict: + return {"content-length": request.content_length} + + with create_test_client([handler]) as client: + response = client.post("/", content=b"1", headers={"content-length": "a"}) + assert response.status_code == HTTP_400_BAD_REQUEST + assert response.json() == {"detail": "Invalid content-length: 'a'", "status_code": 400} diff --git a/tests/unit/test_handlers/test_http_handlers/test_resolution.py b/tests/unit/test_handlers/test_http_handlers/test_resolution.py new file mode 100644 index 0000000000..2f328005ea --- /dev/null +++ b/tests/unit/test_handlers/test_http_handlers/test_resolution.py @@ -0,0 +1,66 @@ +import pytest + +from litestar import Controller, Litestar, Router, post +from litestar.exceptions import ImproperlyConfiguredException +from litestar.types import Empty + + +def test_resolve_request_max_body_size() -> None: + @post("/1") + def router_handler() -> None: + pass + + @post("/2") + def app_handler() -> None: + pass + + class MyController(Controller): + request_max_body_size = 2 + + @post("/3") + def controller_handler(self) -> None: + pass + + router = Router("/", route_handlers=[router_handler], request_max_body_size=1) + app = Litestar(route_handlers=[app_handler, router, MyController], request_max_body_size=3) + assert router_handler.resolve_request_max_body_size() == 1 + assert app_handler.resolve_request_max_body_size() == 3 + assert ( + next(r for r in app.routes if r.path == "/3").route_handler_map["POST"][0].resolve_request_max_body_size() == 2 # type: ignore[union-attr] + ) + + +def test_resolve_request_max_body_size_none() -> None: + @post("/1", request_max_body_size=None) + def router_handler() -> None: + pass + + Litestar([router_handler]) + assert router_handler.resolve_request_max_body_size() is None + + +def test_resolve_request_max_body_size_app_default() -> None: + @post("/") + def router_handler() -> None: + pass + + app = Litestar(route_handlers=[router_handler]) + + assert router_handler.resolve_request_max_body_size() == app.request_max_body_size == 10_000_000 + + +def test_resolve_request_max_body_size_empty_on_all_layers_raises() -> None: + @post("/") + def handler_one() -> None: + pass + + Litestar([handler_one], request_max_body_size=Empty) # type: ignore[arg-type] + with pytest.raises(ImproperlyConfiguredException): + handler_one.resolve_request_max_body_size() + + @post("/") + def handler_two() -> None: + pass + + with pytest.raises(ImproperlyConfiguredException): + handler_two.resolve_request_max_body_size()