Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass args from __init__ to __attrs_pre_init__ if requested #1187

Merged
merged 5 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog.d/1187.change.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
If *attrs* detects that `__attrs_pre_init__` accepts more than just `self`, it will call it with the same arguments as `__init__` was called.
This allows you to, for example, pass arguments to `super().__init__()`.
3 changes: 2 additions & 1 deletion docs/init.md
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,8 @@ However, sometimes you need to do that one quick thing before or after your clas
For that purpose, *attrs* offers the following options:

- `__attrs_pre_init__` is automatically detected and run *before* *attrs* starts initializing.
This is useful if you need to inject a call to `super().__init__()`.
If `__attrs_pre_init__` takes more than the `self` argument, the *attrs*-generated `__init__` will call it with the same arguments it received itself.
This is useful if you need to inject a call to `super().__init__()` -- with or without arguments.

- `__attrs_post_init__` is automatically detected and run *after* *attrs* is done initializing your instance.
This is useful if you want to derive some attribute from others or perform some kind of validation over the whole instance.
Expand Down
26 changes: 26 additions & 0 deletions src/attr/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import copy
import enum
import inspect
import linecache
import sys
import types
Expand Down Expand Up @@ -624,6 +625,7 @@ class _ClassBuilder:
"_delete_attribs",
"_frozen",
"_has_pre_init",
"_pre_init_has_args",
"_has_post_init",
"_is_exc",
"_on_setattr",
Expand Down Expand Up @@ -670,6 +672,13 @@ def __init__(
self._weakref_slot = weakref_slot
self._cache_hash = cache_hash
self._has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False))
self._pre_init_has_args = False
if self._has_pre_init:
# Check if the pre init method has more arguments than just `self`
# We want to pass arguments if pre init expects arguments
pre_init_func = cls.__attrs_pre_init__
pre_init_signature = inspect.signature(pre_init_func)
self._pre_init_has_args = len(pre_init_signature.parameters) > 1
self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False))
self._delete_attribs = not bool(these)
self._is_exc = is_exc
Expand Down Expand Up @@ -974,6 +983,7 @@ def add_init(self):
self._cls,
self._attrs,
self._has_pre_init,
self._pre_init_has_args,
self._has_post_init,
self._frozen,
self._slots,
Expand All @@ -1000,6 +1010,7 @@ def add_attrs_init(self):
self._cls,
self._attrs,
self._has_pre_init,
self._pre_init_has_args,
self._has_post_init,
self._frozen,
self._slots,
Expand Down Expand Up @@ -1984,6 +1995,7 @@ def _make_init(
cls,
attrs,
pre_init,
pre_init_has_args,
post_init,
frozen,
slots,
Expand Down Expand Up @@ -2027,6 +2039,7 @@ def _make_init(
frozen,
slots,
pre_init,
pre_init_has_args,
post_init,
cache_hash,
base_attr_map,
Expand Down Expand Up @@ -2107,6 +2120,7 @@ def _attrs_to_init_script(
frozen,
slots,
pre_init,
pre_init_has_args,
post_init,
cache_hash,
base_attr_map,
Expand Down Expand Up @@ -2361,11 +2375,23 @@ def fmt_setter_with_converter(
lines.append(f"BaseException.__init__(self, {vals})")

args = ", ".join(args)
pre_init_args = args
if kw_only_args:
args += "%s*, %s" % (
", " if args else "", # leading comma
", ".join(kw_only_args), # kw_only args
)
pre_init_kw_only_args = ", ".join(
["%s=%s" % (kw_arg, kw_arg) for kw_arg in kw_only_args]
)
pre_init_args += (
", " if pre_init_args else ""
) # handle only kwargs and no regular args
pre_init_args += pre_init_kw_only_args

if pre_init and pre_init_has_args:
# If pre init method has arguments, pass same arguments as `__init__`
lines[0] = "self.__attrs_pre_init__(%s)" % pre_init_args

return (
"def %s(self, %s):\n %s\n"
Expand Down
8 changes: 7 additions & 1 deletion tests/test_dunders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


import copy
import inspect
import pickle

import pytest
Expand Down Expand Up @@ -84,10 +85,15 @@ def _add_init(cls, frozen):
This function used to be part of _make. It wasn't used anymore however
the tests for it are still useful to test the behavior of _make_init.
"""
has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False))

cls.__init__ = _make_init(
cls,
cls.__attrs_attrs__,
getattr(cls, "__attrs_pre_init__", False),
has_pre_init,
len(inspect.signature(cls.__attrs_pre_init__).parameters) > 1
if has_pre_init
else False,
getattr(cls, "__attrs_post_init__", False),
frozen,
_is_slot_cls(cls),
Expand Down
74 changes: 71 additions & 3 deletions tests/test_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,21 +613,89 @@ class D:
assert C.D.__qualname__ == C.__qualname__ + ".D"

@pytest.mark.parametrize("with_validation", [True, False])
def test_pre_init(self, with_validation, monkeypatch):
def test_pre_init(self, with_validation):
"""
Verify that __attrs_pre_init__ gets called if defined.
"""
monkeypatch.setattr(_config, "_run_validators", with_validation)

@attr.s
class C:
def __attrs_pre_init__(self2):
self2.z = 30

c = C()
try:
attr.validators.set_disabled(not with_validation)
c = C()
finally:
attr.validators.set_disabled(False)

assert 30 == getattr(c, "z", None)

@pytest.mark.parametrize("with_validation", [True, False])
def test_pre_init_args(self, with_validation):
"""
Verify that __attrs_pre_init__ gets called with extra args if defined.
"""

@attr.s
class C:
x = attr.ib()

def __attrs_pre_init__(self2, x):
self2.z = x + 1

try:
attr.validators.set_disabled(not with_validation)
c = C(x=10)
finally:
attr.validators.set_disabled(False)

assert 11 == getattr(c, "z", None)

@pytest.mark.parametrize("with_validation", [True, False])
def test_pre_init_kwargs(self, with_validation):
"""
Verify that __attrs_pre_init__ gets called with extra args and kwargs if defined.
"""

@attr.s
class C:
x = attr.ib()
y = attr.field(kw_only=True)

def __attrs_pre_init__(self2, x, y):
self2.z = x + y + 1

try:
attr.validators.set_disabled(not with_validation)
c = C(10, y=11)
finally:
attr.validators.set_disabled(False)

assert 22 == getattr(c, "z", None)

@pytest.mark.parametrize("with_validation", [True, False])
def test_pre_init_kwargs_only(self, with_validation):
"""
Verify that __attrs_pre_init__ gets called with extra kwargs only if
defined.
"""

@attr.s
class C:
y = attr.field(kw_only=True)

def __attrs_pre_init__(self2, y):
self2.z = y + 1

try:
attr.validators.set_disabled(not with_validation)
c = C(y=11)
finally:
attr.validators.set_disabled(False)

assert 12 == getattr(c, "z", None)

@pytest.mark.parametrize("with_validation", [True, False])
def test_post_init(self, with_validation, monkeypatch):
"""
Expand Down