Skip to content

Commit

Permalink
__attrs_init__() (#731)
Browse files Browse the repository at this point in the history
  • Loading branch information
indigoviolet authored Jan 23, 2021
1 parent 467e28b commit 654aa92
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 14 deletions.
5 changes: 5 additions & 0 deletions changelog.d/731.change.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
``__attrs__init__()`` will now be injected if ``init==False`` or if ``auto_detect=True`` and a user-defined ``__init__()`` exists.

This enables users to do "pre-init" work in their ``__init__()`` (such as ``super().__init__()``).

``__init__()`` can then delegate constructor argument processing to ``__attrs_init__(*args, **kwargs)``.
93 changes: 80 additions & 13 deletions src/attr/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,26 @@ def add_init(self):
self._is_exc,
self._on_setattr is not None
and self._on_setattr is not setters.NO_OP,
attrs_init=False,
)
)

return self

def add_attrs_init(self):
self._cls_dict["__attrs_init__"] = self._add_method_dunders(
_make_init(
self._cls,
self._attrs,
self._has_post_init,
self._frozen,
self._slots,
self._cache_hash,
self._base_attr_map,
self._is_exc,
self._on_setattr is not None
and self._on_setattr is not setters.NO_OP,
attrs_init=True,
)
)

Expand Down Expand Up @@ -1160,6 +1180,11 @@ def attrs(
``attrs`` attributes. Leading underscores are stripped for the
argument name. If a ``__attrs_post_init__`` method exists on the
class, it will be called after the class is fully initialized.
If ``init`` is ``False``, an ``__attrs_init__`` method will be
injected instead. This allows you to define a custom ``__init__``
method that can do pre-init work such as ``super().__init__()``,
and then call ``__attrs_init__()`` and ``__attrs_post_init__()``.
:param bool slots: Create a `slotted class <slotted classes>` that's more
memory-efficient. Slotted classes are generally superior to the default
dict classes, but have some gotchas you should know about, so we
Expand Down Expand Up @@ -1299,6 +1324,8 @@ def attrs(
.. versionadded:: 20.1.0 *getstate_setstate*
.. versionadded:: 20.1.0 *on_setattr*
.. versionadded:: 20.3.0 *field_transformer*
.. versionchanged:: 21.1.0
``init=False`` injects ``__attrs_init__``
"""
if auto_detect and PY2:
raise PythonTooOldError(
Expand Down Expand Up @@ -1408,6 +1435,7 @@ def wrap(cls):
):
builder.add_init()
else:
builder.add_attrs_init()
if cache_hash:
raise TypeError(
"Invalid value for cache_hash. To use hash caching,"
Expand Down Expand Up @@ -1872,6 +1900,7 @@ def _make_init(
base_attr_map,
is_exc,
has_global_on_setattr,
attrs_init,
):
if frozen and has_global_on_setattr:
raise ValueError("Frozen classes can't use on_setattr.")
Expand Down Expand Up @@ -1908,6 +1937,7 @@ def _make_init(
is_exc,
needs_cached_setattr,
has_global_on_setattr,
attrs_init,
)
locs = {}
bytecode = compile(script, unique_filename, "exec")
Expand All @@ -1929,10 +1959,10 @@ def _make_init(
unique_filename,
)

__init__ = locs["__init__"]
__init__.__annotations__ = annotations
init = locs["__attrs_init__"] if attrs_init else locs["__init__"]
init.__annotations__ = annotations

return __init__
return init


def _setattr(attr_name, value_var, has_on_setattr):
Expand Down Expand Up @@ -2047,6 +2077,7 @@ def _attrs_to_init_script(
is_exc,
needs_cached_setattr,
has_global_on_setattr,
attrs_init,
):
"""
Return a script of an initializer for *attrs* and a dict of globals.
Expand Down Expand Up @@ -2317,10 +2348,12 @@ def fmt_setter_with_converter(
)
return (
"""\
def __init__(self, {args}):
def {init_name}(self, {args}):
{lines}
""".format(
args=args, lines="\n ".join(lines) if lines else "pass"
init_name=("__attrs_init__" if attrs_init else "__init__"),
args=args,
lines="\n ".join(lines) if lines else "pass",
),
names_for_globals,
annotations,
Expand Down Expand Up @@ -2666,7 +2699,6 @@ def default(self, meth):
_CountingAttr = _add_eq(_add_repr(_CountingAttr))


@attrs(slots=True, init=False, hash=True)
class Factory(object):
"""
Stores a factory callable.
Expand All @@ -2682,8 +2714,7 @@ class Factory(object):
.. versionadded:: 17.1.0 *takes_self*
"""

factory = attrib()
takes_self = attrib()
__slots__ = ("factory", "takes_self")

def __init__(self, factory, takes_self=False):
"""
Expand All @@ -2693,6 +2724,38 @@ def __init__(self, factory, takes_self=False):
self.factory = factory
self.takes_self = takes_self

def __getstate__(self):
"""
Play nice with pickle.
"""
return tuple(getattr(self, name) for name in self.__slots__)

def __setstate__(self, state):
"""
Play nice with pickle.
"""
for name, value in zip(self.__slots__, state):
setattr(self, name, value)


_f = [
Attribute(
name=name,
default=NOTHING,
validator=None,
repr=True,
cmp=None,
eq=True,
order=False,
hash=True,
init=True,
inherited=False,
)
for name in Factory.__slots__
]

Factory = _add_hash(_add_eq(_add_repr(Factory, attrs=_f), attrs=_f), attrs=_f)


def make_class(name, attrs, bases=(object,), **attributes_arguments):
"""
Expand Down Expand Up @@ -2727,11 +2790,15 @@ def make_class(name, attrs, bases=(object,), **attributes_arguments):
raise TypeError("attrs argument must be a dict or a list.")

post_init = cls_dict.pop("__attrs_post_init__", None)
type_ = type(
name,
bases,
{} if post_init is None else {"__attrs_post_init__": post_init},
)
user_init = cls_dict.pop("__init__", None)

body = {}
if post_init is not None:
body["__attrs_post_init__"] = post_init
if user_init is not None:
body["__init__"] = user_init

type_ = type(name, bases, body)
# For pickling to work, the __module__ variable needs to be set to the
# frame where the class is created. Bypass this step in environments where
# sys._getframe is not defined (Jython for example) or sys._getframe is not
Expand Down
10 changes: 10 additions & 0 deletions tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,29 @@ class HypClass:

cls_dict = dict(zip(attr_names, attrs))
post_init_flag = draw(st.booleans())
init_flag = draw(st.booleans())

if post_init_flag:

def post_init(self):
pass

cls_dict["__attrs_post_init__"] = post_init

if not init_flag:

def init(self, *args, **kwargs):
self.__attrs_init__(*args, **kwargs)

cls_dict["__init__"] = init

return make_class(
"HypClass",
cls_dict,
slots=slots_flag if slots is None else slots,
frozen=frozen_flag if frozen is None else frozen,
weakref_slot=weakref_flag if weakref_slot is None else weakref_slot,
init=init_flag,
)


Expand Down
1 change: 1 addition & 0 deletions tests/test_dunders.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _add_init(cls, frozen):
base_attr_map={},
is_exc=False,
has_global_on_setattr=False,
attrs_init=False,
)
return cls

Expand Down
9 changes: 8 additions & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,14 @@ def test_unknown(self, C):
# No generated class will have a four letter attribute.
with pytest.raises(TypeError) as e:
evolve(C(), aaaa=2)
expected = "__init__() got an unexpected keyword argument 'aaaa'"

if hasattr(C, "__attrs_init__"):
expected = (
"__attrs_init__() got an unexpected keyword argument 'aaaa'"
)
else:
expected = "__init__() got an unexpected keyword argument 'aaaa'"

assert (expected,) == e.value.args

def test_validator_failure(self):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,19 @@ class C(object):

assert sentinel == getattr(C, method_name)

@pytest.mark.parametrize("init", [True, False])
def test_respects_init_attrs_init(self, init):
"""
If init=False, adds __attrs_init__ to the class.
Otherwise, it does not.
"""

class C(object):
x = attr.ib()

C = attr.s(init=init)(C)
assert hasattr(C, "__attrs_init__") != init

@pytest.mark.skipif(PY2, reason="__qualname__ is PY3-only.")
@given(slots_outer=booleans(), slots_inner=booleans())
def test_repr_qualname(self, slots_outer, slots_inner):
Expand Down Expand Up @@ -1527,6 +1540,7 @@ class C(object):
.add_order()
.add_hash()
.add_init()
.add_attrs_init()
.add_repr("ns")
.add_str()
.build_class()
Expand Down

0 comments on commit 654aa92

Please sign in to comment.