diff --git a/litestar/template/config.py b/litestar/template/config.py index d2aa87c302..ecff5d3f58 100644 --- a/litestar/template/config.py +++ b/litestar/template/config.py @@ -54,4 +54,8 @@ def to_engine(self) -> EngineType: @cached_property def engine_instance(self) -> EngineType: """Return the template engine instance.""" - return self.to_engine() if self.instance is None else self.instance + if self.instance is None: + return self.to_engine() + if callable(self.engine_callback): + self.engine_callback(self.instance) + return self.instance diff --git a/tests/unit/test_template/test_template.py b/tests/unit/test_template/test_template.py index b0d70d1f11..13f9c64e83 100644 --- a/tests/unit/test_template/test_template.py +++ b/tests/unit/test_template/test_template.py @@ -2,9 +2,10 @@ import sys from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest +from jinja2 import DictLoader, Environment from litestar import Litestar, MediaType, get from litestar.contrib.jinja import JinjaTemplateEngine @@ -52,6 +53,39 @@ def callback(engine: TemplateEngineProtocol) -> None: assert received_engine is app.template_engine +def test_engine_passed_to_callback_custom_env(tmp_path: Path) -> None: + received_engine: JinjaTemplateEngine | None = None + + def callback(engine: TemplateEngineProtocol) -> None: + nonlocal received_engine + assert isinstance(engine, JinjaTemplateEngine), "Engine must be a JinjaTemplateEngine" + received_engine = engine + engine.register_template_callable( + key="check_context_key", + template_callable=my_template_function, + ) + + def my_template_function(ctx: dict[str, Any]) -> str: + return ctx.get("my_context_key", "nope") + + my_custom_env = Environment(loader=DictLoader({"index.html": "check_context_key: {{ check_context_key() }}"})) + template_config = TemplateConfig( + instance=JinjaTemplateEngine.from_environment(my_custom_env), + engine_callback=callback, + ) + app = Litestar(template_config=template_config) + + @get("/") + def handler() -> Template: + return Template(template_name="index.html") + + assert received_engine is not None + assert received_engine is app.template_engine + with create_test_client(route_handlers=[handler], template_config=template_config) as client: + response = client.get("/") + assert response.text == "check_context_key: nope" + + @pytest.mark.parametrize("engine", (JinjaTemplateEngine, MakoTemplateEngine, MiniJinjaTemplateEngine)) def test_engine_instance(engine: type[TemplateEngineProtocol], tmp_path: Path) -> None: engine_instance = engine(directory=tmp_path, engine_instance=None)