Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

__attrs_init__() #731

Merged
merged 14 commits into from
Jan 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog.d/731.change.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
``__attrs__init__()`` will now be injected if ``init==False`` or if ``auto_detect=True`` and a user-defined ``__init__()`` exists.

This enables users to do "pre-init" work in their ``__init__()`` (such as ``super().__init__()``).

``__init__()`` can then delegate constructor argument processing to ``__attrs_init__(*args, **kwargs)``.
93 changes: 80 additions & 13 deletions src/attr/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,26 @@ def add_init(self):
self._is_exc,
self._on_setattr is not None
and self._on_setattr is not setters.NO_OP,
attrs_init=False,
)
)

return self

def add_attrs_init(self):
self._cls_dict["__attrs_init__"] = self._add_method_dunders(
_make_init(
self._cls,
self._attrs,
self._has_post_init,
self._frozen,
self._slots,
self._cache_hash,
self._base_attr_map,
self._is_exc,
self._on_setattr is not None
and self._on_setattr is not setters.NO_OP,
attrs_init=True,
)
)

Expand Down Expand Up @@ -1160,6 +1180,11 @@ def attrs(
``attrs`` attributes. Leading underscores are stripped for the
argument name. If a ``__attrs_post_init__`` method exists on the
class, it will be called after the class is fully initialized.

If ``init`` is ``False``, an ``__attrs_init__`` method will be
injected instead. This allows you to define a custom ``__init__``
method that can do pre-init work such as ``super().__init__()``,
and then call ``__attrs_init__()`` and ``__attrs_post_init__()``.
:param bool slots: Create a `slotted class <slotted classes>` that's more
memory-efficient. Slotted classes are generally superior to the default
dict classes, but have some gotchas you should know about, so we
Expand Down Expand Up @@ -1299,6 +1324,8 @@ def attrs(
.. versionadded:: 20.1.0 *getstate_setstate*
.. versionadded:: 20.1.0 *on_setattr*
.. versionadded:: 20.3.0 *field_transformer*
.. versionchanged:: 21.1.0
``init=False`` injects ``__attrs_init__``
"""
if auto_detect and PY2:
raise PythonTooOldError(
Expand Down Expand Up @@ -1408,6 +1435,7 @@ def wrap(cls):
):
builder.add_init()
else:
builder.add_attrs_init()
if cache_hash:
raise TypeError(
"Invalid value for cache_hash. To use hash caching,"
Expand Down Expand Up @@ -1872,6 +1900,7 @@ def _make_init(
base_attr_map,
is_exc,
has_global_on_setattr,
attrs_init,
):
if frozen and has_global_on_setattr:
raise ValueError("Frozen classes can't use on_setattr.")
Expand Down Expand Up @@ -1908,6 +1937,7 @@ def _make_init(
is_exc,
needs_cached_setattr,
has_global_on_setattr,
attrs_init,
)
locs = {}
bytecode = compile(script, unique_filename, "exec")
Expand All @@ -1929,10 +1959,10 @@ def _make_init(
unique_filename,
)

__init__ = locs["__init__"]
__init__.__annotations__ = annotations
init = locs["__attrs_init__"] if attrs_init else locs["__init__"]
hynek marked this conversation as resolved.
Show resolved Hide resolved
init.__annotations__ = annotations

return __init__
return init


def _setattr(attr_name, value_var, has_on_setattr):
Expand Down Expand Up @@ -2047,6 +2077,7 @@ def _attrs_to_init_script(
is_exc,
needs_cached_setattr,
has_global_on_setattr,
attrs_init,
):
"""
Return a script of an initializer for *attrs* and a dict of globals.
Expand Down Expand Up @@ -2317,10 +2348,12 @@ def fmt_setter_with_converter(
)
return (
"""\
def __init__(self, {args}):
def {init_name}(self, {args}):
{lines}
""".format(
args=args, lines="\n ".join(lines) if lines else "pass"
init_name=("__attrs_init__" if attrs_init else "__init__"),
args=args,
lines="\n ".join(lines) if lines else "pass",
),
names_for_globals,
annotations,
Expand Down Expand Up @@ -2666,7 +2699,6 @@ def default(self, meth):
_CountingAttr = _add_eq(_add_repr(_CountingAttr))


@attrs(slots=True, init=False, hash=True)
class Factory(object):
"""
Stores a factory callable.
Expand All @@ -2682,8 +2714,7 @@ class Factory(object):
.. versionadded:: 17.1.0 *takes_self*
"""

factory = attrib()
takes_self = attrib()
__slots__ = ("factory", "takes_self")

def __init__(self, factory, takes_self=False):
"""
Expand All @@ -2693,6 +2724,38 @@ def __init__(self, factory, takes_self=False):
self.factory = factory
self.takes_self = takes_self

def __getstate__(self):
"""
Play nice with pickle.
"""
return tuple(getattr(self, name) for name in self.__slots__)

def __setstate__(self, state):
"""
Play nice with pickle.
"""
for name, value in zip(self.__slots__, state):
setattr(self, name, value)


_f = [
Attribute(
name=name,
default=NOTHING,
validator=None,
repr=True,
cmp=None,
eq=True,
order=False,
hash=True,
init=True,
inherited=False,
)
for name in Factory.__slots__
]

Factory = _add_hash(_add_eq(_add_repr(Factory, attrs=_f), attrs=_f), attrs=_f)


def make_class(name, attrs, bases=(object,), **attributes_arguments):
"""
Expand Down Expand Up @@ -2727,11 +2790,15 @@ def make_class(name, attrs, bases=(object,), **attributes_arguments):
raise TypeError("attrs argument must be a dict or a list.")

post_init = cls_dict.pop("__attrs_post_init__", None)
type_ = type(
name,
bases,
{} if post_init is None else {"__attrs_post_init__": post_init},
)
user_init = cls_dict.pop("__init__", None)
hynek marked this conversation as resolved.
Show resolved Hide resolved

body = {}
if post_init is not None:
body["__attrs_post_init__"] = post_init
if user_init is not None:
body["__init__"] = user_init

type_ = type(name, bases, body)
# For pickling to work, the __module__ variable needs to be set to the
# frame where the class is created. Bypass this step in environments where
# sys._getframe is not defined (Jython for example) or sys._getframe is not
Expand Down
10 changes: 10 additions & 0 deletions tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,29 @@ class HypClass:

cls_dict = dict(zip(attr_names, attrs))
post_init_flag = draw(st.booleans())
init_flag = draw(st.booleans())

if post_init_flag:

def post_init(self):
pass

cls_dict["__attrs_post_init__"] = post_init

if not init_flag:

def init(self, *args, **kwargs):
self.__attrs_init__(*args, **kwargs)

cls_dict["__init__"] = init

return make_class(
"HypClass",
cls_dict,
slots=slots_flag if slots is None else slots,
frozen=frozen_flag if frozen is None else frozen,
weakref_slot=weakref_flag if weakref_slot is None else weakref_slot,
init=init_flag,
)


Expand Down
1 change: 1 addition & 0 deletions tests/test_dunders.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _add_init(cls, frozen):
base_attr_map={},
is_exc=False,
has_global_on_setattr=False,
attrs_init=False,
)
return cls

Expand Down
9 changes: 8 additions & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,14 @@ def test_unknown(self, C):
# No generated class will have a four letter attribute.
with pytest.raises(TypeError) as e:
evolve(C(), aaaa=2)
expected = "__init__() got an unexpected keyword argument 'aaaa'"

if hasattr(C, "__attrs_init__"):
expected = (
"__attrs_init__() got an unexpected keyword argument 'aaaa'"
)
else:
expected = "__init__() got an unexpected keyword argument 'aaaa'"

assert (expected,) == e.value.args

def test_validator_failure(self):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,19 @@ class C(object):

assert sentinel == getattr(C, method_name)

@pytest.mark.parametrize("init", [True, False])
def test_respects_init_attrs_init(self, init):
"""
If init=False, adds __attrs_init__ to the class.
Otherwise, it does not.
"""

class C(object):
x = attr.ib()

C = attr.s(init=init)(C)
assert hasattr(C, "__attrs_init__") != init

@pytest.mark.skipif(PY2, reason="__qualname__ is PY3-only.")
@given(slots_outer=booleans(), slots_inner=booleans())
def test_repr_qualname(self, slots_outer, slots_inner):
Expand Down Expand Up @@ -1527,6 +1540,7 @@ class C(object):
.add_order()
.add_hash()
.add_init()
.add_attrs_init()
.add_repr("ns")
.add_str()
.build_class()
Expand Down