Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add request_max_size option #2175

Closed
3 changes: 3 additions & 0 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
lifespan: typing.Optional[Lifespan["AppType"]] = None,
request_max_size: int = 2621440, # 2.5mb
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like in Django

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you permalink where you got this from in a comment below here just for the record? I'd also be okay with a code comment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does flask do?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

) -> None:
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
Expand All @@ -78,6 +79,7 @@ def __init__(
)
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack: typing.Optional[ASGIApp] = None
self.request_max_size = request_max_size
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be a good idea to start making these attributes private.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, all members of Starlette are public. I would not like to introduce a new style here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not true:

self._url = url

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. I still think a lot of those shouldn't have been public in the first place and just because they are doesn't mean we should make this public as well.


def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
Expand Down Expand Up @@ -117,6 +119,7 @@ def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
scope["request_max_size"] = self.request_max_size
alex-oleshkevich marked this conversation as resolved.
Show resolved Hide resolved
if self.middleware_stack is None:
self.middleware_stack = self.build_middleware_stack()
await self.middleware_stack(scope, receive, send)
Expand Down
24 changes: 22 additions & 2 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.types import ASGIApp, Lifespan, Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose


Expand Down Expand Up @@ -205,12 +205,14 @@ def __init__(
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
include_in_schema: bool = True,
request_max_size: typing.Optional[int] = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.include_in_schema = include_in_schema
self.request_max_size = request_max_size

endpoint_handler = endpoint
while isinstance(endpoint_handler, functools.partial):
Expand Down Expand Up @@ -273,7 +275,25 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
)
await response(scope, receive, send)
else:
await self.app(scope, receive, send)
bytes_read = 0
max_body_size = self.request_max_size or scope.get("request_max_size")

async def receive_wrapper() -> Message:
nonlocal bytes_read
message = await receive()
if message["type"] != "http.request" or max_body_size is None:
return message

body = message.get("body", b"")
bytes_read += len(body)
if bytes_read > max_body_size:
raise HTTPException(
Copy link
Member Author

@alex-oleshkevich alex-oleshkevich Jun 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the exception good enough?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah fine by me

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please follow the same logic as above.

        if "app" in scope:
            raise HTTPException(status_code=413, "Request Entity Too Large")
        else:
            response = PlainTextResponse("Request Entity Too Large", status_code=404)

alex-oleshkevich marked this conversation as resolved.
Show resolved Hide resolved
status_code=413, detail="Request Entity Too Large"
)

return message

await self.app(scope, receive_wrapper, send)

def __eq__(self, other: typing.Any) -> bool:
return (
Expand Down
20 changes: 20 additions & 0 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,26 @@ def test_500(test_client_factory):
assert response.json() == {"detail": "Server Error"}


def test_413(test_client_factory):
async def read_view(request):
content = await request.body()
return JSONResponse(content.decode())

app = Starlette(
request_max_size=10, # 10 bytes
routes=[Route("/", endpoint=read_view, methods=["POST"])],
)

client = test_client_factory(app, raise_server_exceptions=True)
response = client.post("/", data=b"youshallnotpass")
assert response.status_code == 413
assert response.text == "Request Entity Too Large"

response = client.post("/", data=b"ok")
assert response.status_code == 200
assert response.text == '"ok"'


def test_websocket_raise_websocket_exception(client):
with client.websocket_connect("/ws-raise-websocket") as session:
response = session.receive()
Expand Down
37 changes: 37 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,3 +1125,40 @@ async def startup() -> None:
... # pragma: nocover

router.on_event("startup")(startup)


def test_request_body_size_limit(test_client_factory):
async def read_view(request):
content = await request.body()
return JSONResponse(content.decode())

app = Starlette(
routes=[Route("/", endpoint=read_view, methods=["POST"], request_max_size=10)],
)

client = test_client_factory(app, raise_server_exceptions=True)
response = client.post("/", data=b"youshallnotpass")
assert response.status_code == 413
assert response.text == "Request Entity Too Large"

response = client.post("/", data=b"ok")
assert response.status_code == 200
assert response.text == '"ok"'


def test_request_body_size_limit_route_has_higher_precedense(test_client_factory):
async def read_view(request):
content = await request.body()
return JSONResponse(content.decode())

app = Starlette(
request_max_size=5,
routes=[Route("/", endpoint=read_view, methods=["POST"], request_max_size=24)],
)

# app caps at 5 bytes, route has 24 byte limit
# payload of size 12 bytes should pass
client = test_client_factory(app, raise_server_exceptions=True)
response = client.post("/", data=b"youshallpass") # 12 bytes
assert response.status_code == 200
assert response.text == '"youshallpass"'