diff --git a/docs/middleware.md b/docs/middleware.md index 92ac5886a..23768d7d1 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -106,8 +106,7 @@ The following arguments are supported: * `max_age` - Session expiry time in seconds. Defaults to 2 weeks. If set to `None` then the cookie will last as long as the browser session. * `same_site` - SameSite flag prevents the browser from sending session cookie along with cross-site requests. Defaults to `'lax'`. * `https_only` - Indicate that Secure flag should be set (can be used with HTTPS only). Defaults to `False`. -* `domain` - Domain of the cookie used to share cookie between subdomains or cross-domains. The browser defaults the domain to the same host that set the cookie, excluding subdomains [refrence](https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#domain_attribute). - +* `domain` - Domain of the cookie used to share cookie between subdomains or cross-domains. The browser defaults the domain to the same host that set the cookie, excluding subdomains [refrence](https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#domain_attribute). ## HTTPSRedirectMiddleware @@ -184,6 +183,31 @@ The following arguments are supported: The middleware won't GZip responses that already have a `Content-Encoding` set, to prevent them from being encoded twice. +## LimitBodySizeMiddleware + +Limits the body size of incoming requests. If the incoming request has a body +larger than the limit, then a `413 Content Too Large` response will be sent. + +```python +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.limits import LimitBodySizeMiddleware + + +routes = ... + +middleware = [ + Middleware(LimitBodySizeMiddleware, max_body_size=1024) +] + +app = Starlette(routes=routes, middleware=middleware) +``` + +The following arguments are supported: + +* `max_body_size` - Send a "413 - Content Too Large" on requests that surpass the maximum allowed body size. + Defaults to `2.5 * 1024 * 1024` bytes (2.5MB). + ## BaseHTTPMiddleware An abstract class that allows you to write ASGI middleware against a request/response @@ -573,7 +597,7 @@ import time class MonitoringMiddleware: def __init__(self, app): self.app = app - + async def __call__(self, scope, receive, send): start = time.time() try: diff --git a/starlette/middleware/limits.py b/starlette/middleware/limits.py new file mode 100644 index 000000000..5e625f611 --- /dev/null +++ b/starlette/middleware/limits.py @@ -0,0 +1,74 @@ +"""Middleware that limits the body size of incoming requests.""" +from starlette.datastructures import Headers +from starlette.responses import PlainTextResponse +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +DEFAULT_MAX_BODY_SIZE = 2_621_440 # 2.5MB +MAX_BODY_SIZE_KEY = "starlette.max_body_size" + + +class ContentTooLarge(Exception): + def __init__(self, max_body_size: int) -> None: + self.max_body_size = max_body_size + + +class SetBodySizeLimit: + def __init__(self, app: ASGIApp, max_body_size: int) -> None: + self.app = app + self.max_body_size = max_body_size + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + scope[MAX_BODY_SIZE_KEY] = self.max_body_size + await self.app(scope, receive, send) + +class LimitBodySizeMiddleware: + def __init__( + self, app: ASGIApp + ) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": # pragma: no cover + return await self.app(scope, receive, send) + + total_size = 0 + response_started = False + headers = Headers(scope=scope) + content_length = headers.get("content-length") + + max_body_size = scope.get(MAX_BODY_SIZE_KEY, None) + if not max_body_size: + return await self.app(scope, receive, send) + + async def wrap_send(message: Message) -> None: + nonlocal response_started + if message["type"] == "http.response.start": + response_started = True + await send(message) + + async def wrap_receive() -> Message: + nonlocal total_size + + if content_length is not None: + if int(content_length) > max_body_size: + raise ContentTooLarge(max_body_size) + + message = await receive() + + if message["type"] == "http.request": + chunk_size = len(message["body"]) + total_size += chunk_size + if total_size > max_body_size: + raise ContentTooLarge(max_body_size) + + return message + + try: + await self.app(scope, wrap_receive, wrap_send) + except ContentTooLarge as exc: + # NOTE: If response has already started, we can't return a 413, because the + # headers have already been sent. + if not response_started: + response = PlainTextResponse("Content Too Large", status_code=413) + return await response(scope, receive, send) + raise exc diff --git a/tests/middleware/test_limits.py b/tests/middleware/test_limits.py new file mode 100644 index 000000000..204b1c78e --- /dev/null +++ b/tests/middleware/test_limits.py @@ -0,0 +1,174 @@ +from typing import AsyncGenerator, Callable + +import pytest + +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.limits import ContentTooLarge, LimitBodySizeMiddleware, SetBodySizeLimit +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount, Route +from starlette.testclient import TestClient +from starlette.types import ASGIApp, Message, Receive, Scope, Send + + +async def echo_app(scope: Scope, receive: Receive, send: Send) -> None: + while True: + message = await receive() + more_body = message.get("more_body", False) + body = message.get("body", b"") + + await send({"type": "http.response.start", "status": 200, "headers": []}) + await send({"type": "http.response.body", "body": body, "more_body": more_body}) + + if not more_body: + break + + +app = SetBodySizeLimit(LimitBodySizeMiddleware(echo_app), max_body_size=1024) + + +def test_no_op(test_client_factory: Callable[..., TestClient]) -> None: + client = test_client_factory(app) + + response = client.post("/", content="Small payload") + assert response.status_code == 200 + assert response.text == "Small payload" + + +def test_content_too_large(test_client_factory: Callable[..., TestClient]) -> None: + client = test_client_factory(app) + + response = client.post("/", content="X" * 1025) + assert response.status_code == 413 + assert response.text == "Content Too Large" + + +def test_content_too_large_on_streaming_body( + test_client_factory: Callable[..., TestClient] +) -> None: + client = test_client_factory(app) + + response = client.post("/", content=[b"X" * 1025]) + assert response.status_code == 413 + assert response.text == "Content Too Large" + + +@pytest.mark.anyio +async def test_content_too_large_on_started_response() -> None: + scope: Scope = {"type": "http", "method": "POST", "path": "/", "headers": []} + + async def receive() -> AsyncGenerator[Message, None]: + yield {"type": "http.request", "body": b"X" * 1024, "more_body": True} + yield {"type": "http.request", "body": b"X", "more_body": False} + + async def send(message: Message) -> None: + ... + + rcv = receive() + + with pytest.raises(ContentTooLarge) as ctx: + await app(scope, rcv.__anext__, send) + assert ctx.value.max_body_size == 1024 + + await rcv.aclose() + + +async def read_body_endpoint(request: Request) -> Response: + body = b"" + async for chunk in request.stream(): + body += chunk + return Response(content=body) + + +def test_content_too_large_on_starlette( + test_client_factory: Callable[..., TestClient] +) -> None: + app = Starlette( + routes=[Mount("/", routes=[Route("/", read_body_endpoint, methods=["POST"])], middleware=[Middleware(LimitBodySizeMiddleware)])], + middleware=[Middleware(LimitBodySizeMiddleware, max_body_size=1024)], + ) + client = test_client_factory(app) + + response = client.post("/", content=b"X" * 1024) + assert response.status_code == 200 + assert response.text == "X" * 1024 + + response = client.post("/", content=[b"X" * 1024, b"X"]) + assert response.status_code == 413 + assert response.text == "Content Too Large" + + +def test_content_too_large_and_content_length_mismatch( + test_client_factory: Callable[..., TestClient] +) -> None: + client = test_client_factory(app) + + response = client.post("/", content="X" * 1025, headers={"Content-Length": "1024"}) + assert response.status_code == 413 + assert response.text == "Content Too Large" + + +def test_inner_middleware_overrides_outer_middleware( + test_client_factory: Callable[..., TestClient] +) -> None: + class CopyScopeMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + scope = dict(scope) + await self.app(scope, receive, send) + + outer_app = SetBodySizeLimit( + CopyScopeMiddleware( + SetBodySizeLimit( + LimitBodySizeMiddleware(echo_app), + max_body_size=2048, + ) + ), + max_body_size=1024, + ) + + client = test_client_factory(outer_app) + + response = client.post("/", content="X" * 2049) + assert response.status_code == 413 + assert response.text == "Content Too Large" + + response = client.post("/", content="X" * 2048) + assert response.status_code == 200 + assert response.text == "X" * 2048 + + +def test_multiple_middleware_on_starlette( + test_client_factory: Callable[..., TestClient] +) -> None: + app = Starlette( + routes=[ + Route("/outer", read_body_endpoint, methods=["POST"]), + Mount( + "/inner", + routes=[Route("/", read_body_endpoint, methods=["POST"])], + middleware=[Middleware(LimitBodySizeMiddleware, max_body_size=2048)], + ), + ], + middleware=[Middleware(LimitBodySizeMiddleware, max_body_size=1024)], + ) + client = test_client_factory(app) + + response = client.post("/outer", content="X" * 1024) + assert response.status_code == 200 + assert response.text == "X" * 1024 + + response = client.post("/outer", content="X" * 1025) + assert response.status_code == 413 + assert response.text == "Content Too Large" + + response = client.post("/inner", content="X" * 2048) + assert response.status_code == 200 + assert response.text == "X" * 2048 + + response = client.post("/inner", content="X" * 2049) + assert response.status_code == 413 + assert response.text == "Content Too Large"