Skip to content

Commit

Permalink
Properly check for typing_extensions variant of TypeAliasType (py…
Browse files Browse the repository at this point in the history
  • Loading branch information
Daraan authored Oct 25, 2024
1 parent 12f89be commit 13979e9
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 9 deletions.
6 changes: 3 additions & 3 deletions pydantic/_internal/_core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@

from pydantic_core import CoreSchema, core_schema
from pydantic_core import validate_core_schema as _validate_core_schema
from typing_extensions import TypeAliasType, TypeGuard, get_args, get_origin
from typing_extensions import TypeGuard, get_args, get_origin

from . import _repr
from ._typing_extra import is_generic_alias
from ._typing_extra import TYPE_ALIAS_TYPES, is_generic_alias

AnyFunctionSchema = Union[
core_schema.AfterValidatorFunctionSchema,
Expand Down Expand Up @@ -85,7 +85,7 @@ def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None =
args = generic_metadata['args'] or args

module_name = getattr(origin, '__module__', '<No __module__>')
if isinstance(origin, TypeAliasType):
if isinstance(origin, TYPE_ALIAS_TYPES):
type_ref = f'{module_name}.{origin.__name__}:{id(origin)}'
else:
try:
Expand Down
15 changes: 11 additions & 4 deletions pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,14 @@
from ._mock_val_ser import MockCoreSchema
from ._namespace_utils import NamespacesTuple, NsResolver
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import get_cls_type_hints, is_annotated, is_finalvar, is_self_type, is_zoneinfo_type
from ._typing_extra import (
TYPE_ALIAS_TYPES,
get_cls_type_hints,
is_annotated,
is_finalvar,
is_self_type,
is_zoneinfo_type,
)
from ._utils import lenient_issubclass, smart_deepcopy
from ._validate_call import VALIDATE_CALL_SUPPORTED_TYPES, ValidateCallSupportedTypes

Expand Down Expand Up @@ -986,7 +993,7 @@ def match_type(self, obj: Any) -> core_schema.CoreSchema: # noqa: C901
return self._sequence_schema(Any)
elif obj in DICT_TYPES:
return self._dict_schema(Any, Any)
elif isinstance(obj, TypeAliasType):
elif isinstance(obj, TYPE_ALIAS_TYPES):
return self._type_alias_type_schema(obj)
elif obj is type:
return self._type_schema()
Expand Down Expand Up @@ -1050,7 +1057,7 @@ def _match_generic_type(self, obj: Any, origin: Any) -> CoreSchema: # noqa: C90
if from_property is not None:
return from_property

if isinstance(origin, TypeAliasType):
if isinstance(origin, TYPE_ALIAS_TYPES):
return self._type_alias_type_schema(obj)
elif _typing_extra.origin_is_union(origin):
return self._union_schema(obj)
Expand Down Expand Up @@ -1675,7 +1682,7 @@ def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema:

if type_param == Any:
return self._type_schema()
elif isinstance(type_param, TypeAliasType):
elif isinstance(type_param, TYPE_ALIAS_TYPES):
return self.generate_schema(typing.Type[type_param.__value__])
elif isinstance(type_param, typing.TypeVar):
if type_param.__bound__:
Expand Down
10 changes: 8 additions & 2 deletions pydantic/_internal/_typing_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ def origin_is_union(tp: type[Any] | None) -> bool:

NONE_TYPES: tuple[Any, ...] = (None, NoneType, *(tp[None] for tp in LITERAL_TYPES))

# should check for both variant of types for typing_extensions > 4.12.2
# https://typing-extensions.readthedocs.io/en/latest/#runtime-use-of-types
TYPE_ALIAS_TYPES: tuple[type, ...] = (
(TypeAliasType, typing.TypeAliasType) if hasattr(typing, 'TypeAliasType') else (TypeAliasType,)
)


TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type

Expand Down Expand Up @@ -594,8 +600,8 @@ def is_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
return dataclasses.is_dataclass(_cls)


def origin_is_type_alias_type(origin: Any) -> TypeGuard[TypeAliasType]:
return isinstance(origin, TypeAliasType)
def origin_is_type_alias_type(origin: Any) -> TypeGuard[TypeAliasType | typing.TypeAliasType]:
return isinstance(origin, TYPE_ALIAS_TYPES)


if sys.version_info >= (3, 10):
Expand Down

0 comments on commit 13979e9

Please sign in to comment.