Skip to content

Commit

Permalink
fix: #1301 - Apply compression before caching (#2393)
Browse files Browse the repository at this point in the history
* fix: #1301 - Apply compression before caching

Signed-off-by: Janek Nouvertné <[email protected]>

* fix typing

Signed-off-by: Janek Nouvertné <[email protected]>

---------

Signed-off-by: Janek Nouvertné <[email protected]>
  • Loading branch information
provinzkraut authored Oct 1, 2023
1 parent 1b47eec commit 61b71d4
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 36 deletions.
6 changes: 6 additions & 0 deletions litestar/_asgi/routing_trie/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def build_route_middleware_stack(
from litestar.middleware.allowed_hosts import AllowedHostsMiddleware
from litestar.middleware.compression import CompressionMiddleware
from litestar.middleware.csrf import CSRFMiddleware
from litestar.middleware.response_cache import ResponseCacheMiddleware
from litestar.routes import HTTPRoute

# we wrap the route.handle method in the ExceptionHandlerMiddleware
asgi_handler = wrap_in_exception_handler(
Expand All @@ -197,6 +199,10 @@ def build_route_middleware_stack(

if app.compression_config:
asgi_handler = CompressionMiddleware(app=asgi_handler, config=app.compression_config)

if isinstance(route, HTTPRoute) and any(r.cache for r in route.route_handlers):
asgi_handler = ResponseCacheMiddleware(app=asgi_handler, config=app.response_cache_config)

if app.allowed_hosts:
asgi_handler = AllowedHostsMiddleware(app=asgi_handler, config=app.allowed_hosts)

Expand Down
1 change: 1 addition & 0 deletions litestar/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
SCOPE_STATE_DEPENDENCY_CACHE: Final = "dependency_cache"
SCOPE_STATE_NAMESPACE: Final = "__litestar__"
SCOPE_STATE_RESPONSE_COMPRESSED: Final = "response_compressed"
SCOPE_STATE_IS_CACHED: Final = "is_cached"
SKIP_VALIDATION_NAMES: Final = {"request", "socket", "scope", "receive", "send"}
UNDEFINED_SENTINELS: Final = {Signature.empty, Empty, Ellipsis, MISSING, UnsetType}
WEBSOCKET_CLOSE: Final = "websocket.close"
Expand Down
11 changes: 9 additions & 2 deletions litestar/middleware/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from io import BytesIO
from typing import TYPE_CHECKING, Any, Literal, Optional

from litestar.constants import SCOPE_STATE_RESPONSE_COMPRESSED
from litestar.constants import SCOPE_STATE_IS_CACHED, SCOPE_STATE_RESPONSE_COMPRESSED
from litestar.datastructures import Headers, MutableScopeHeaders
from litestar.enums import CompressionEncoding, ScopeType
from litestar.exceptions import MissingDependencyException
from litestar.middleware.base import AbstractMiddleware
from litestar.utils import Ref, set_litestar_scope_state
from litestar.utils import Ref, get_litestar_scope_state, set_litestar_scope_state

__all__ = ("CompressionFacade", "CompressionMiddleware")

Expand Down Expand Up @@ -176,6 +176,8 @@ def create_compression_send_wrapper(
initial_message = Ref[Optional["HTTPResponseStartEvent"]](None)
started = Ref[bool](False)

_own_encoding = compression_encoding.encode("latin-1")

async def send_wrapper(message: Message) -> None:
"""Handle and compresses the HTTP Message with brotli.
Expand All @@ -187,6 +189,11 @@ async def send_wrapper(message: Message) -> None:
initial_message.value = message
return

if initial_message.value and get_litestar_scope_state(scope, SCOPE_STATE_IS_CACHED):
await send(initial_message.value)
await send(message)
return

if initial_message.value and message["type"] == "http.response.body":
body = message["body"]
more_body = message.get("more_body")
Expand Down
48 changes: 48 additions & 0 deletions litestar/middleware/response_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

from msgspec.msgpack import encode as encode_msgpack

from litestar.enums import ScopeType
from litestar.utils import get_litestar_scope_state

from .base import AbstractMiddleware

__all__ = ["ResponseCacheMiddleware"]

from typing import TYPE_CHECKING, cast

from litestar import Request
from litestar.constants import SCOPE_STATE_IS_CACHED

if TYPE_CHECKING:
from litestar.config.response_cache import ResponseCacheConfig
from litestar.handlers import HTTPRouteHandler
from litestar.types import ASGIApp, Message, Receive, Scope, Send


class ResponseCacheMiddleware(AbstractMiddleware):
def __init__(self, app: ASGIApp, config: ResponseCacheConfig) -> None:
self.config = config
super().__init__(app=app, scopes={ScopeType.HTTP})

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
route_handler = cast("HTTPRouteHandler", scope["route_handler"])
store = self.config.get_store_from_app(scope["app"])

expires_in: int | None = None
if route_handler.cache is True:
expires_in = self.config.default_expiration
elif route_handler.cache is not False and isinstance(route_handler.cache, int):
expires_in = route_handler.cache

messages = []

async def wrapped_send(message: Message) -> None:
if not get_litestar_scope_state(scope, SCOPE_STATE_IS_CACHED):
messages.append(message)
if message["type"] == "http.response.body" and not message["more_body"]:
key = (route_handler.cache_key_builder or self.config.key_builder)(Request(scope))
await store.set(key, encode_msgpack(messages), expires_in=expires_in)
await send(message)

await self.app(scope, receive, wrapped_send)
48 changes: 15 additions & 33 deletions litestar/routes/http.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import pickle
from itertools import chain
from typing import TYPE_CHECKING, Any, cast

from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS
from msgspec.msgpack import decode as _decode_msgpack_plain

from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS, SCOPE_STATE_IS_CACHED
from litestar.datastructures.headers import Headers
from litestar.datastructures.upload_file import UploadFile
from litestar.enums import HttpMethod, MediaType, ScopeType
Expand All @@ -13,6 +14,7 @@
from litestar.response import Response
from litestar.routes.base import BaseRoute
from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
from litestar.utils import set_litestar_scope_state

if TYPE_CHECKING:
from litestar._kwargs import KwargsModel
Expand Down Expand Up @@ -128,19 +130,10 @@ async def _get_response_for_request(
):
return response

response = await self._call_handler_function(
return await self._call_handler_function(
scope=scope, request=request, parameter_model=parameter_model, route_handler=route_handler
)

if route_handler.cache:
await self._set_cached_response(
response=response,
request=request,
route_handler=route_handler,
)

return response

async def _call_handler_function(
self, scope: Scope, request: Request, parameter_model: KwargsModel, route_handler: HTTPRouteHandler
) -> ASGIApp:
Expand Down Expand Up @@ -225,30 +218,19 @@ async def _get_cached_response(request: Request, route_handler: HTTPRouteHandler
cache_key = (route_handler.cache_key_builder or cache_config.key_builder)(request)
store = cache_config.get_store_from_app(request.app)

cached_response = await store.get(key=cache_key)

if cached_response:
return cast("ASGIApp", pickle.loads(cached_response)) # noqa: S301
if not (cached_response_data := await store.get(key=cache_key)):
return None

return None
# we use the regular msgspec.msgpack.decode here since we don't need any of
# the added decoders
messages = _decode_msgpack_plain(cached_response_data)

@staticmethod
async def _set_cached_response(
response: Response | ASGIApp, request: Request, route_handler: HTTPRouteHandler
) -> None:
"""Pickles and caches a response object."""
cache_config = request.app.response_cache_config
cache_key = (route_handler.cache_key_builder or cache_config.key_builder)(request)

expires_in: int | None = None
if route_handler.cache is True:
expires_in = cache_config.default_expiration
elif route_handler.cache is not False and isinstance(route_handler.cache, int):
expires_in = route_handler.cache

store = cache_config.get_store_from_app(request.app)
async def cached_response(scope: Scope, receive: Receive, send: Send) -> None:
set_litestar_scope_state(scope, SCOPE_STATE_IS_CACHED, True)
for message in messages:
await send(message)

await store.set(key=cache_key, value=pickle.dumps(response, pickle.HIGHEST_PROTOCOL), expires_in=expires_in)
return cached_response

def create_options_handler(self, path: str) -> HTTPRouteHandler:
"""Args:
Expand Down
70 changes: 69 additions & 1 deletion tests/e2e/test_response_caching.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import gzip
import random
from datetime import timedelta
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Type, Union
from unittest.mock import MagicMock
from uuid import uuid4

import msgspec
import pytest

from litestar import Litestar, Request, get
from litestar.config.compression import CompressionConfig
from litestar.config.response_cache import CACHE_FOREVER, ResponseCacheConfig
from litestar.enums import CompressionEncoding
from litestar.middleware.response_cache import ResponseCacheMiddleware
from litestar.stores.base import Store
from litestar.stores.memory import MemoryStore
from litestar.testing import TestClient, create_test_client
Expand Down Expand Up @@ -180,3 +185,66 @@ def handler() -> str:
assert response_two.text == mock.return_value

assert mock.call_count == 1


def test_does_not_apply_to_non_cached_routes(mock: MagicMock) -> None:
@get("/")
def handler() -> str:
return mock() # type: ignore[no-any-return]

with create_test_client([handler]) as client:
first_response = client.get("/")
second_response = client.get("/")

assert first_response.status_code == 200
assert second_response.status_code == 200
assert mock.call_count == 2


@pytest.mark.parametrize(
"cache,expect_applied",
[
(True, True),
(False, False),
(1, True),
(CACHE_FOREVER, True),
],
)
def test_middleware_not_applied_to_non_cached_routes(
cache: Union[bool, int, Type[CACHE_FOREVER]], expect_applied: bool
) -> None:
@get(path="/", cache=cache)
def handler() -> None:
...

client = create_test_client(route_handlers=[handler])
unpacked_middleware = []
cur = client.app.asgi_router.root_route_map_node.children["/"].asgi_handlers["GET"][0]
while hasattr(cur, "app"):
unpacked_middleware.append(cur)
cur = cur.app
unpacked_middleware.append(cur)

assert len([m for m in unpacked_middleware if isinstance(m, ResponseCacheMiddleware)]) == int(expect_applied)


async def test_compression_applies_before_cache() -> None:
return_value = "_litestar_" * 4000
mock = MagicMock(return_value=return_value)

@get(path="/", cache=True)
def handler_fn() -> str:
return mock() # type: ignore[no-any-return]

app = Litestar(
route_handlers=[handler_fn],
compression_config=CompressionConfig(backend="gzip"),
)

with TestClient(app) as client:
client.get("/", headers={"Accept-Encoding": str(CompressionEncoding.GZIP.value)})

stored_value = await app.response_cache_config.get_store_from_app(app).get("/")
assert stored_value
stored_messages = msgspec.msgpack.decode(stored_value)
assert gzip.decompress(stored_messages[1]["body"]).decode() == return_value
23 changes: 23 additions & 0 deletions tests/unit/test_middleware/test_compression_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,26 @@ async def fake_send(message: Message) -> None:
# second body message with more_body=True will be empty if zlib buffers output and is not flushed
await wrapped_send(HTTPResponseBodyEvent(type="http.response.body", body=b"abc", more_body=True))
assert mock.mock_calls[-1].args[0]["body"]


@pytest.mark.parametrize(
"backend, compression_encoding", (("brotli", CompressionEncoding.BROTLI), ("gzip", CompressionEncoding.GZIP))
)
def test_dont_recompress_cached(backend: Literal["gzip", "brotli"], compression_encoding: CompressionEncoding) -> None:
mock = MagicMock(return_value="_litestar_" * 4000)

@get(path="/", media_type=MediaType.TEXT, cache=True)
def handler_fn() -> str:
return mock() # type: ignore[no-any-return]

with create_test_client(
route_handlers=[handler_fn], compression_config=CompressionConfig(backend=backend)
) as client:
client.get("/", headers={"Accept-Encoding": str(compression_encoding.value)})
response = client.get("/", headers={"Accept-Encoding": str(compression_encoding.value)})

assert mock.call_count == 1
assert response.status_code == HTTP_200_OK
assert response.text == "_litestar_" * 4000
assert response.headers["Content-Encoding"] == compression_encoding
assert int(response.headers["Content-Length"]) < 40000

0 comments on commit 61b71d4

Please sign in to comment.