Skip to content

Commit

Permalink
Switched from __dataclass_transform__() to typing.dataclass_transform…
Browse files Browse the repository at this point in the history
…() (#1158)
  • Loading branch information
superbobry authored Jul 6, 2023
1 parent 620cd59 commit 36d4762
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 46 deletions.
3 changes: 3 additions & 0 deletions changelog.d/1158.change.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Type stubs now use `typing.dataclass_transform` to decorate dataclass-like
decorators, instead of the non-standard `__dataclass_transform__` special
form, which is only supported by pyright.
19 changes: 2 additions & 17 deletions docs/extending.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,31 +94,16 @@ You can only use this trick to tell *Mypy* that a class is actually an *attrs* c

### Pyright

Generic decorator wrapping is supported in [*Pyright*](https://github.com/microsoft/pyright) via `dataclass_transform` / {pep}`689`.
Generic decorator wrapping is supported in [*Pyright*](https://github.com/microsoft/pyright) via `typing.dataclass_transform` / {pep}`689`.

For a custom wrapping of the form:

```
@typing.dataclass_transform(field_specifiers=(attr.attrib, attr.field))
def custom_define(f):
return attr.define(f)
```

This is implemented via a `__dataclass_transform__` type decorator in the custom extension's `.pyi` of the form:

```
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]: ...
@__dataclass_transform__(field_descriptors=(attr.attrib, attr.field))
def custom_define(f): ...
```


## Types

*attrs* offers two ways of attaching type information to attributes:
Expand Down
38 changes: 11 additions & 27 deletions src/attr/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ if sys.version_info >= (3, 10):
else:
from typing_extensions import TypeGuard

if sys.version_info >= (3, 11):
from typing import dataclass_transform
else:
from typing_extensions import dataclass_transform

__version__: str
__version_info__: VersionInfo
__title__: str
Expand Down Expand Up @@ -103,23 +108,6 @@ else:
takes_self: bool = ...,
) -> _T: ...

# Static type inference support via __dataclass_transform__ implemented as per:
# https://github.com/microsoft/pyright/blob/1.1.135/specs/dataclass_transforms.md
# This annotation must be applied to all overloads of "define" and "attrs"
#
# NOTE: This is a typing construct and does not exist at runtime. Extensions
# wrapping attrs decorators should declare a separate __dataclass_transform__
# signature in the extension module using the specification linked above to
# provide pyright support.
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
frozen_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]: ...

class Attribute(Generic[_T]):
name: str
default: Optional[_T]
Expand Down Expand Up @@ -322,7 +310,7 @@ def field(
type: Optional[type] = ...,
) -> Any: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
@dataclass_transform(order_default=True, field_specifiers=(attrib, field))
def attrs(
maybe_cls: _C,
these: Optional[Dict[str, Any]] = ...,
Expand Down Expand Up @@ -350,7 +338,7 @@ def attrs(
unsafe_hash: Optional[bool] = ...,
) -> _C: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
@dataclass_transform(order_default=True, field_specifiers=(attrib, field))
def attrs(
maybe_cls: None = ...,
these: Optional[Dict[str, Any]] = ...,
Expand Down Expand Up @@ -378,7 +366,7 @@ def attrs(
unsafe_hash: Optional[bool] = ...,
) -> Callable[[_C], _C]: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
@dataclass_transform(field_specifiers=(attrib, field))
def define(
maybe_cls: _C,
*,
Expand All @@ -404,7 +392,7 @@ def define(
match_args: bool = ...,
) -> _C: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
@dataclass_transform(field_specifiers=(attrib, field))
def define(
maybe_cls: None = ...,
*,
Expand Down Expand Up @@ -433,9 +421,7 @@ def define(
mutable = define

@overload
@__dataclass_transform__(
frozen_default=True, field_descriptors=(attrib, field)
)
@dataclass_transform(frozen_default=True, field_specifiers=(attrib, field))
def frozen(
maybe_cls: _C,
*,
Expand All @@ -461,9 +447,7 @@ def frozen(
match_args: bool = ...,
) -> _C: ...
@overload
@__dataclass_transform__(
frozen_default=True, field_descriptors=(attrib, field)
)
@dataclass_transform(frozen_default=True, field_specifiers=(attrib, field))
def frozen(
maybe_cls: None = ...,
*,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def parse_pyright_output(test_file: Path) -> set[PyrightDiagnostic]:

def test_pyright_baseline():
"""
The __dataclass_transform__ decorator allows pyright to determine attrs
decorated class types.
The typing.dataclass_transform decorator allows pyright to determine
attrs decorated class types.
"""

test_file = Path(__file__).parent / "dataclass_transform_example.py"
Expand Down

0 comments on commit 36d4762

Please sign in to comment.