Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for fields with local dataclass types #180

Merged
merged 3 commits into from
Jan 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,20 @@
is_hashable,
is_init_var,
is_literal,
is_local_type_name,
is_named_tuple,
is_optional,
is_type_var_any,
resolve_type_params,
substitute_type_params,
type_name,
)
from mashumaro.core.meta.types.common import FieldContext, NoneType, ValueSpec
from mashumaro.core.meta.types.common import (
FieldContext,
NoneType,
ValueSpec,
clean_id,
)
from mashumaro.core.meta.types.pack import PackerRegistry
from mashumaro.core.meta.types.unpack import (
SubtypeUnpackerBuilder,
Expand Down Expand Up @@ -210,6 +216,21 @@ def get_field_types(
) -> typing.Dict[str, typing.Any]:
return self.__get_field_types(include_extras=include_extras)

def get_type_name_identifier(
self,
typ: typing.Optional[typing.Type],
resolved_type_params: typing.Optional[
typing.Dict[typing.Type, typing.Type]
] = None,
) -> str:
field_type = type_name(typ, resolved_type_params=resolved_type_params)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
self.ensure_object_imported(typ, field_type)

return field_type

@property # type: ignore
@lru_cache()
def dataclass_fields(self) -> typing.Dict[str, Field]:
Expand Down Expand Up @@ -306,6 +327,7 @@ def compile(self) -> None:
else:
print(f"{type_name(self.cls)}:")
print(code)

exec(code, self.globals, self.__dict__)

def evaluate_forward_ref(
Expand Down Expand Up @@ -1250,12 +1272,13 @@ def build(
) -> FieldUnpackerCodeBlock:
default = self.parent.get_field_default(fname)
has_default = default is not MISSING
field_type = type_name(
field_type = self.parent.get_type_name_identifier(
ftype,
resolved_type_params=self.parent.get_field_resolved_type_params(
fname
),
)

could_be_none = (
ftype in (typing.Any, type(None), None)
or is_type_var_any(self.parent.get_real_type(fname, ftype))
Expand Down
4 changes: 4 additions & 0 deletions mashumaro/core/meta/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ def is_literal(typ: Type) -> bool:
return False


def is_local_type_name(type_name: str) -> bool:
return "<locals>" in type_name


def not_none_type_arg(
type_args: Tuple[Type, ...],
resolved_type_params: Optional[Dict[Type, Type]] = None,
Expand Down
11 changes: 8 additions & 3 deletions mashumaro/core/meta/types/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections.abc
import re
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, replace
Expand Down Expand Up @@ -300,7 +301,11 @@ def random_hex() -> str:
return str(uuid.uuid4().hex)


_PY_VALID_ID_RE = re.compile(r"\W|^(?=\d)")


def clean_id(value: str) -> str:
for c in ".<>":
value = value.replace(c, "_")
return value
if not value:
return "_"

return _PY_VALID_ID_RE.sub("_", value)
9 changes: 6 additions & 3 deletions mashumaro/core/meta/types/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,13 @@ def pack_union(
with lines.indent("try:"):
lines.append(f"return {packer}")
lines.append("except Exception: pass")
field_type = type_name(
field_type = spec.builder.get_type_name_identifier(
spec.type,
resolved_type_params=spec.builder.get_field_resolved_type_params(
spec.field_ctx.name
),
)

if spec.builder.is_nailed:
lines.append(
"raise InvalidFieldValue("
Expand Down Expand Up @@ -355,10 +356,11 @@ def pack_literal(spec: ValueSpec) -> Expression:
spec.copy(type=value_type, expression="value")
)
if isinstance(literal_value, enum.Enum):
enum_type_name = type_name(
enum_type_name = spec.builder.get_type_name_identifier(
typ=value_type,
resolved_type_params=resolved_type_params,
)

with lines.indent(
f"if value == {enum_type_name}.{literal_value.name}:"
):
Expand All @@ -369,10 +371,11 @@ def pack_literal(spec: ValueSpec) -> Expression:
):
with lines.indent(f"if value == {literal_value!r}:"):
lines.append(f"return {packer}")
field_type = type_name(
field_type = spec.builder.get_type_name_identifier(
typ=spec.type,
resolved_type_params=resolved_type_params,
)

if spec.builder.is_nailed:
lines.append(
f"raise InvalidFieldValue('{spec.field_ctx.name}',"
Expand Down
42 changes: 32 additions & 10 deletions mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
with lines.indent("try:"):
lines.append(f"return {unpacker}")
lines.append("except Exception: pass")
field_type = type_name(
field_type = spec.builder.get_type_name_identifier(
spec.type,
resolved_type_params=spec.builder.get_field_resolved_type_params(
spec.field_ctx.name
Expand Down Expand Up @@ -203,7 +203,11 @@ def get_method_prefix(self) -> str:
def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
for literal_value in get_literal_values(spec.type):
if isinstance(literal_value, enum.Enum):
enum_type_name = type_name(type(literal_value))
lit_type = type(literal_value)
enum_type_name = spec.builder.get_type_name_identifier(
lit_type
)

with lines.indent(
f"if value == {enum_type_name}.{literal_value.name}.value:"
):
Expand Down Expand Up @@ -298,7 +302,7 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
variants_map = self._get_variants_map(spec)
variants_attr_holder = self._get_variants_attr_holder(spec)
variants = self._get_variant_names_iterable(spec)
variants_type_expr = type_name(spec.type)
variants_type_expr = spec.builder.get_type_name_identifier(spec.type)

if variants_attr not in variants_attr_holder.__dict__:
setattr(variants_attr_holder, variants_attr, {})
Expand Down Expand Up @@ -565,7 +569,10 @@ def _unpack_annotated_serializable_type(
],
)
unpacker = UnpackerRegistry.get(spec.copy(type=value_type))
return f"{type_name(spec.type)}._deserialize({unpacker})"

field_type = spec.builder.get_type_name_identifier(spec.type)

return f"{field_type}._deserialize({unpacker})"


@register
Expand All @@ -578,7 +585,9 @@ def unpack_serializable_type(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type.__use_annotations__:
return _unpack_annotated_serializable_type(spec)
else:
return f"{type_name(spec.type)}._deserialize({spec.expression})"
field_type = spec.builder.get_type_name_identifier(spec.type)

return f"{field_type}._deserialize({spec.expression})"


@register
Expand All @@ -588,8 +597,12 @@ def unpack_generic_serializable_type(spec: ValueSpec) -> Optional[Expression]:
type_arg_names = ", ".join(
list(map(type_name, get_args(spec.type)))
)
field_type = spec.builder.get_type_name_identifier(
spec.origin_type
)

return (
f"{type_name(spec.type)}._deserialize({spec.expression}, "
f"{field_type}._deserialize({spec.expression}, "
f"[{type_arg_names}])"
)

Expand Down Expand Up @@ -990,7 +1003,9 @@ def unpack_named_tuple(spec: ValueSpec) -> Expression:
unpackers.append(unpacker)

if not defaults:
return f"{type_name(spec.type)}({', '.join(unpackers)})"
field_type = spec.builder.get_type_name_identifier(spec.type)

return f"{field_type}({', '.join(unpackers)})"

lines = CodeLines()
method_name = (
Expand All @@ -1015,7 +1030,10 @@ def unpack_named_tuple(spec: ValueSpec) -> Expression:
lines.append(f"fields.append({unpacker})")
with lines.indent("except IndexError:"):
lines.append("pass")
lines.append(f"return {type_name(spec.type)}(*fields)")

field_type = spec.builder.get_type_name_identifier(spec.type)

lines.append(f"return {field_type}(*fields)")
lines.append(
f"setattr({spec.cls_attrs_name}, '{method_name}', {method_name})"
)
Expand Down Expand Up @@ -1194,10 +1212,14 @@ def unpack_pathlike(spec: ValueSpec) -> Optional[Expression]:
spec.builder.ensure_module_imported(pathlib)
return f"{type_name(pathlib.PurePath)}({spec.expression})"
elif issubclass(spec.origin_type, os.PathLike):
return f"{type_name(spec.origin_type)}({spec.expression})"
field_type = spec.builder.get_type_name_identifier(spec.origin_type)

return f"{field_type}({spec.expression})"


@register
def unpack_enum(spec: ValueSpec) -> Optional[Expression]:
if issubclass(spec.origin_type, enum.Enum):
return f"{type_name(spec.origin_type)}({spec.expression})"
field_type = spec.builder.get_type_name_identifier(spec.origin_type)

return f"{field_type}({spec.expression})"
83 changes: 83 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import dataclasses
from dataclasses import dataclass, field
from enum import Enum
from pathlib import PurePosixPath
from typing import Any, Literal, NamedTuple

import msgpack
import pytest
Expand Down Expand Up @@ -262,3 +265,83 @@ def test_kw_args_when_pos_arg_is_overridden_with_field():
assert loaded.pos2 == 2
assert loaded.pos3 == 3
assert loaded.kw1 == 4


def test_local_types():
@dataclass
class LocalDataclassType:
foo: int

class LocalNamedTupleType(NamedTuple):
foo: int

class LocalPathLike(PurePosixPath):
pass

class LocalEnumType(Enum):
FOO = "foo"

class LocalSerializableType(SerializableType):
@classmethod
def _deserialize(self, value):
return LocalSerializableType()

def _serialize(self) -> Any:
return {}

def __eq__(self, __value: object) -> bool:
return isinstance(__value, LocalSerializableType)

class LocalGenericSerializableType(GenericSerializableType):
@classmethod
def _deserialize(self, value, types):
return LocalGenericSerializableType()

def _serialize(self, types) -> Any:
return {}

def __eq__(self, __value: object) -> bool:
return isinstance(__value, LocalGenericSerializableType)

@dataclass
class DataClassWithLocalType(DataClassDictMixin):
x1: LocalDataclassType
x2: LocalNamedTupleType
x3: LocalPathLike
x4: LocalEnumType
x4_1: Literal[LocalEnumType.FOO]
x5: LocalSerializableType
x6: LocalGenericSerializableType

obj = DataClassWithLocalType(
x1=LocalDataclassType(foo=0),
x2=LocalNamedTupleType(foo=0),
x3=LocalPathLike("path/to/file"),
x4=LocalEnumType.FOO,
x4_1=LocalEnumType.FOO,
x5=LocalSerializableType(),
x6=LocalGenericSerializableType(),
)
assert obj.to_dict() == {
"x1": {"foo": 0},
"x2": [0],
"x3": "path/to/file",
"x4": "foo",
"x4_1": "foo",
"x5": {},
"x6": {},
}
assert (
DataClassWithLocalType.from_dict(
{
"x1": {"foo": 0},
"x2": [0],
"x3": "path/to/file",
"x4": "foo",
"x4_1": "foo",
"x5": {},
"x6": {},
}
)
== obj
)