diff --git a/mashumaro/core/meta/code/builder.py b/mashumaro/core/meta/code/builder.py index a3daba27..4e65ae5f 100644 --- a/mashumaro/core/meta/code/builder.py +++ b/mashumaro/core/meta/code/builder.py @@ -45,6 +45,7 @@ is_hashable, is_init_var, is_literal, + is_local_type_name, is_named_tuple, is_optional, is_type_var_any, @@ -52,7 +53,12 @@ substitute_type_params, type_name, ) -from mashumaro.core.meta.types.common import FieldContext, NoneType, ValueSpec +from mashumaro.core.meta.types.common import ( + FieldContext, + NoneType, + ValueSpec, + clean_id, +) from mashumaro.core.meta.types.pack import PackerRegistry from mashumaro.core.meta.types.unpack import ( SubtypeUnpackerBuilder, @@ -210,6 +216,21 @@ def get_field_types( ) -> typing.Dict[str, typing.Any]: return self.__get_field_types(include_extras=include_extras) + def get_type_name_identifier( + self, + typ: typing.Optional[typing.Type], + resolved_type_params: typing.Optional[ + typing.Dict[typing.Type, typing.Type] + ] = None, + ) -> str: + field_type = type_name(typ, resolved_type_params=resolved_type_params) + + if is_local_type_name(field_type): + field_type = clean_id(field_type) + self.ensure_object_imported(typ, field_type) + + return field_type + @property # type: ignore @lru_cache() def dataclass_fields(self) -> typing.Dict[str, Field]: @@ -306,6 +327,7 @@ def compile(self) -> None: else: print(f"{type_name(self.cls)}:") print(code) + exec(code, self.globals, self.__dict__) def evaluate_forward_ref( @@ -1250,12 +1272,13 @@ def build( ) -> FieldUnpackerCodeBlock: default = self.parent.get_field_default(fname) has_default = default is not MISSING - field_type = type_name( + field_type = self.parent.get_type_name_identifier( ftype, resolved_type_params=self.parent.get_field_resolved_type_params( fname ), ) + could_be_none = ( ftype in (typing.Any, type(None), None) or is_type_var_any(self.parent.get_real_type(fname, ftype)) diff --git a/mashumaro/core/meta/helpers.py b/mashumaro/core/meta/helpers.py index eee4ab8f..d73d696b 100644 --- a/mashumaro/core/meta/helpers.py +++ b/mashumaro/core/meta/helpers.py @@ -389,6 +389,10 @@ def is_literal(typ: Type) -> bool: return False +def is_local_type_name(type_name: str) -> bool: + return "" in type_name + + def not_none_type_arg( type_args: Tuple[Type, ...], resolved_type_params: Optional[Dict[Type, Type]] = None, diff --git a/mashumaro/core/meta/types/common.py b/mashumaro/core/meta/types/common.py index 40f7976b..42245203 100644 --- a/mashumaro/core/meta/types/common.py +++ b/mashumaro/core/meta/types/common.py @@ -1,4 +1,5 @@ import collections.abc +import re import uuid from abc import ABC, abstractmethod from dataclasses import dataclass, field, replace @@ -300,7 +301,11 @@ def random_hex() -> str: return str(uuid.uuid4().hex) +_PY_VALID_ID_RE = re.compile(r"\W|^(?=\d)") + + def clean_id(value: str) -> str: - for c in ".<>": - value = value.replace(c, "_") - return value + if not value: + return "_" + + return _PY_VALID_ID_RE.sub("_", value) diff --git a/mashumaro/core/meta/types/pack.py b/mashumaro/core/meta/types/pack.py index a7fc4ed0..1e181ae3 100644 --- a/mashumaro/core/meta/types/pack.py +++ b/mashumaro/core/meta/types/pack.py @@ -300,12 +300,13 @@ def pack_union( with lines.indent("try:"): lines.append(f"return {packer}") lines.append("except Exception: pass") - field_type = type_name( + field_type = spec.builder.get_type_name_identifier( spec.type, resolved_type_params=spec.builder.get_field_resolved_type_params( spec.field_ctx.name ), ) + if spec.builder.is_nailed: lines.append( "raise InvalidFieldValue(" @@ -355,10 +356,11 @@ def pack_literal(spec: ValueSpec) -> Expression: spec.copy(type=value_type, expression="value") ) if isinstance(literal_value, enum.Enum): - enum_type_name = type_name( + enum_type_name = spec.builder.get_type_name_identifier( typ=value_type, resolved_type_params=resolved_type_params, ) + with lines.indent( f"if value == {enum_type_name}.{literal_value.name}:" ): @@ -369,10 +371,11 @@ def pack_literal(spec: ValueSpec) -> Expression: ): with lines.indent(f"if value == {literal_value!r}:"): lines.append(f"return {packer}") - field_type = type_name( + field_type = spec.builder.get_type_name_identifier( typ=spec.type, resolved_type_params=resolved_type_params, ) + if spec.builder.is_nailed: lines.append( f"raise InvalidFieldValue('{spec.field_ctx.name}'," diff --git a/mashumaro/core/meta/types/unpack.py b/mashumaro/core/meta/types/unpack.py index b979d2a3..976cb0a2 100644 --- a/mashumaro/core/meta/types/unpack.py +++ b/mashumaro/core/meta/types/unpack.py @@ -173,7 +173,7 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None: with lines.indent("try:"): lines.append(f"return {unpacker}") lines.append("except Exception: pass") - field_type = type_name( + field_type = spec.builder.get_type_name_identifier( spec.type, resolved_type_params=spec.builder.get_field_resolved_type_params( spec.field_ctx.name @@ -203,7 +203,11 @@ def get_method_prefix(self) -> str: def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None: for literal_value in get_literal_values(spec.type): if isinstance(literal_value, enum.Enum): - enum_type_name = type_name(type(literal_value)) + lit_type = type(literal_value) + enum_type_name = spec.builder.get_type_name_identifier( + lit_type + ) + with lines.indent( f"if value == {enum_type_name}.{literal_value.name}.value:" ): @@ -298,7 +302,7 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None: variants_map = self._get_variants_map(spec) variants_attr_holder = self._get_variants_attr_holder(spec) variants = self._get_variant_names_iterable(spec) - variants_type_expr = type_name(spec.type) + variants_type_expr = spec.builder.get_type_name_identifier(spec.type) if variants_attr not in variants_attr_holder.__dict__: setattr(variants_attr_holder, variants_attr, {}) @@ -565,7 +569,10 @@ def _unpack_annotated_serializable_type( ], ) unpacker = UnpackerRegistry.get(spec.copy(type=value_type)) - return f"{type_name(spec.type)}._deserialize({unpacker})" + + field_type = spec.builder.get_type_name_identifier(spec.type) + + return f"{field_type}._deserialize({unpacker})" @register @@ -578,7 +585,9 @@ def unpack_serializable_type(spec: ValueSpec) -> Optional[Expression]: if spec.origin_type.__use_annotations__: return _unpack_annotated_serializable_type(spec) else: - return f"{type_name(spec.type)}._deserialize({spec.expression})" + field_type = spec.builder.get_type_name_identifier(spec.type) + + return f"{field_type}._deserialize({spec.expression})" @register @@ -588,8 +597,12 @@ def unpack_generic_serializable_type(spec: ValueSpec) -> Optional[Expression]: type_arg_names = ", ".join( list(map(type_name, get_args(spec.type))) ) + field_type = spec.builder.get_type_name_identifier( + spec.origin_type + ) + return ( - f"{type_name(spec.type)}._deserialize({spec.expression}, " + f"{field_type}._deserialize({spec.expression}, " f"[{type_arg_names}])" ) @@ -990,7 +1003,9 @@ def unpack_named_tuple(spec: ValueSpec) -> Expression: unpackers.append(unpacker) if not defaults: - return f"{type_name(spec.type)}({', '.join(unpackers)})" + field_type = spec.builder.get_type_name_identifier(spec.type) + + return f"{field_type}({', '.join(unpackers)})" lines = CodeLines() method_name = ( @@ -1015,7 +1030,10 @@ def unpack_named_tuple(spec: ValueSpec) -> Expression: lines.append(f"fields.append({unpacker})") with lines.indent("except IndexError:"): lines.append("pass") - lines.append(f"return {type_name(spec.type)}(*fields)") + + field_type = spec.builder.get_type_name_identifier(spec.type) + + lines.append(f"return {field_type}(*fields)") lines.append( f"setattr({spec.cls_attrs_name}, '{method_name}', {method_name})" ) @@ -1194,10 +1212,14 @@ def unpack_pathlike(spec: ValueSpec) -> Optional[Expression]: spec.builder.ensure_module_imported(pathlib) return f"{type_name(pathlib.PurePath)}({spec.expression})" elif issubclass(spec.origin_type, os.PathLike): - return f"{type_name(spec.origin_type)}({spec.expression})" + field_type = spec.builder.get_type_name_identifier(spec.origin_type) + + return f"{field_type}({spec.expression})" @register def unpack_enum(spec: ValueSpec) -> Optional[Expression]: if issubclass(spec.origin_type, enum.Enum): - return f"{type_name(spec.origin_type)}({spec.expression})" + field_type = spec.builder.get_type_name_identifier(spec.origin_type) + + return f"{field_type}({spec.expression})" diff --git a/tests/test_common.py b/tests/test_common.py index f832cc43..d62f15fa 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,5 +1,8 @@ import dataclasses from dataclasses import dataclass, field +from enum import Enum +from pathlib import PurePosixPath +from typing import Any, Literal, NamedTuple import msgpack import pytest @@ -262,3 +265,83 @@ def test_kw_args_when_pos_arg_is_overridden_with_field(): assert loaded.pos2 == 2 assert loaded.pos3 == 3 assert loaded.kw1 == 4 + + +def test_local_types(): + @dataclass + class LocalDataclassType: + foo: int + + class LocalNamedTupleType(NamedTuple): + foo: int + + class LocalPathLike(PurePosixPath): + pass + + class LocalEnumType(Enum): + FOO = "foo" + + class LocalSerializableType(SerializableType): + @classmethod + def _deserialize(self, value): + return LocalSerializableType() + + def _serialize(self) -> Any: + return {} + + def __eq__(self, __value: object) -> bool: + return isinstance(__value, LocalSerializableType) + + class LocalGenericSerializableType(GenericSerializableType): + @classmethod + def _deserialize(self, value, types): + return LocalGenericSerializableType() + + def _serialize(self, types) -> Any: + return {} + + def __eq__(self, __value: object) -> bool: + return isinstance(__value, LocalGenericSerializableType) + + @dataclass + class DataClassWithLocalType(DataClassDictMixin): + x1: LocalDataclassType + x2: LocalNamedTupleType + x3: LocalPathLike + x4: LocalEnumType + x4_1: Literal[LocalEnumType.FOO] + x5: LocalSerializableType + x6: LocalGenericSerializableType + + obj = DataClassWithLocalType( + x1=LocalDataclassType(foo=0), + x2=LocalNamedTupleType(foo=0), + x3=LocalPathLike("path/to/file"), + x4=LocalEnumType.FOO, + x4_1=LocalEnumType.FOO, + x5=LocalSerializableType(), + x6=LocalGenericSerializableType(), + ) + assert obj.to_dict() == { + "x1": {"foo": 0}, + "x2": [0], + "x3": "path/to/file", + "x4": "foo", + "x4_1": "foo", + "x5": {}, + "x6": {}, + } + assert ( + DataClassWithLocalType.from_dict( + { + "x1": {"foo": 0}, + "x2": [0], + "x3": "path/to/file", + "x4": "foo", + "x4_1": "foo", + "x5": {}, + "x6": {}, + } + ) + == obj + )