Skip to content

Commit

Permalink
fix(OpenAPI): Correctly handle type keyword (#3715)
Browse files Browse the repository at this point in the history
* support type keyword

* make pyright happy

* fix tests

* fix coverage upload
  • Loading branch information
provinzkraut authored Sep 3, 2024
1 parent d8b7d6e commit c6173a3
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 3 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ jobs:
with:
name: coverage-data
path: .coverage.pydantic_v1
include-hidden-files: true

upload-test-coverage:
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@ jobs:
with:
name: coverage-data
path: .coverage.${{ inputs.python-version }}
include-hidden-files: true
12 changes: 12 additions & 0 deletions litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re

if field_definition.is_new_type:
result = self.for_new_type(field_definition)
elif field_definition.is_type_alias_type:
result = self.for_type_alias_type(field_definition)
elif plugin_for_annotation := self.get_plugin_for(field_definition):
result = self.for_plugin(field_definition, plugin_for_annotation)
elif _should_create_enum_schema(field_definition):
Expand Down Expand Up @@ -366,6 +368,16 @@ def for_new_type(self, field_definition: FieldDefinition) -> Schema | Reference:
)
)

def for_type_alias_type(self, field_definition: FieldDefinition) -> Schema | Reference:
return self.for_field_definition(
FieldDefinition.from_kwarg(
annotation=field_definition.annotation.__value__,
name=field_definition.name,
default=field_definition.default,
kwarg_definition=field_definition.kwarg_definition,
)
)

@staticmethod
def for_upload_file(field_definition: FieldDefinition) -> Schema:
"""Create schema for UploadFile.
Expand Down
17 changes: 16 additions & 1 deletion litestar/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,17 @@
from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, Protocol, Sequence, TypeVar, cast

from msgspec import UnsetType
from typing_extensions import NewType, NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict
from typing_extensions import (
NewType,
NotRequired,
Required,
Self,
TypeAliasType,
get_args,
get_origin,
get_type_hints,
is_typeddict,
)

from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning
from litestar.openapi.spec import Example
Expand Down Expand Up @@ -363,6 +373,11 @@ def is_tuple(self) -> bool:
def is_new_type(self) -> bool:
return isinstance(self.annotation, NewType)

@property
def is_type_alias_type(self) -> bool:
"""Whether the annotation is a ``TypeAliasType``"""
return isinstance(self.annotation, TypeAliasType)

@property
def is_type_var(self) -> bool:
"""Whether the annotation is a TypeVar or not."""
Expand Down
31 changes: 30 additions & 1 deletion tests/unit/test_openapi/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import msgspec
import pytest
from msgspec import Struct
from typing_extensions import Annotated, TypeAlias
from typing_extensions import Annotated, TypeAlias, TypeAliasType

from litestar import Controller, MediaType, get, post
from litestar._openapi.schema_generation.plugins import openapi_schema_plugins
Expand Down Expand Up @@ -615,3 +615,32 @@ async def handler(dep: str) -> None:
assert param.name == f"param{i}"
assert param.required is True
assert param.param_in is ParamType.PATH


def test_type_alias_type() -> None:
@get("/")
def handler(query_param: Annotated[TypeAliasType("IntAlias", int), Parameter(description="foo")]) -> None: # type: ignore[valid-type]
pass

app = Litestar([handler])
param = app.openapi_schema.paths["/"].get.parameters[0] # type: ignore[index, union-attr]
assert param.schema.type is OpenAPIType.INTEGER # type: ignore[union-attr]
# ensure other attributes than the plain type are carried over correctly
assert param.description == "foo"


@pytest.mark.skipif(sys.version_info < (3, 12), reason="type keyword not available before 3.12")
def test_type_alias_type_keyword() -> None:
ctx: Dict[str, Any] = {}
exec("type IntAlias = int", ctx, None)
annotation = ctx["IntAlias"]

@get("/")
def handler(query_param: Annotated[annotation, Parameter(description="foo")]) -> None: # type: ignore[valid-type]
pass

app = Litestar([handler])
param = app.openapi_schema.paths["/"].get.parameters[0] # type: ignore[union-attr, index]
assert param.schema.type is OpenAPIType.INTEGER # type: ignore[union-attr]
# ensure other attributes than the plain type are carried over correctly
assert param.description == "foo"
16 changes: 15 additions & 1 deletion tests/unit/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import annotated_types
import msgspec
import pytest
from typing_extensions import Annotated, NotRequired, Required, TypedDict, get_type_hints
from typing_extensions import Annotated, NotRequired, Required, TypeAliasType, TypedDict, get_type_hints

from litestar import get
from litestar.exceptions import LitestarWarning
Expand Down Expand Up @@ -476,3 +476,17 @@ def handler(foo: Annotated[int, Parameter(default=1)]) -> None:
(record,) = warnings
assert record.category == DeprecationWarning
assert "Deprecated default value specification" in str(record.message)


def test_is_type_alias_type() -> None:
field_definition = FieldDefinition.from_annotation(TypeAliasType("IntAlias", int)) # pyright: ignore
assert field_definition.is_type_alias_type


@pytest.mark.skipif(sys.version_info < (3, 12), reason="type keyword not available before 3.12")
def test_unwrap_type_alias_type_keyword() -> None:
ctx: dict[str, Any] = {}
exec("type IntAlias = int", ctx, None)
annotation = ctx["IntAlias"]
field_definition = FieldDefinition.from_annotation(annotation)
assert field_definition.is_type_alias_type

0 comments on commit c6173a3

Please sign in to comment.