Skip to content

Commit

Permalink
Correctly resolve mro for __getattr__ in cached properties
Browse files Browse the repository at this point in the history
  • Loading branch information
Danny Cooper committed Nov 13, 2023
1 parent dd13b75 commit d447b1c
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 13 deletions.
56 changes: 44 additions & 12 deletions src/attr/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,18 +599,46 @@ def _transform_attrs(
return _Attributes((AttrsClass(attrs), base_attrs, base_attr_map))


def _make_cached_property_getattr(cached_properties, original_getattr=None):
def __getattr__(instance, item: str):
func = cached_properties.get(item)
if func is not None:
result = func(instance)
_obj_setattr(instance, item, result)
return result
if original_getattr is not None:
return original_getattr(instance, item)
raise AttributeError(item)
def _make_cached_property_getattr(
cached_properties,
cls,
):
lines = [
# Wrapped to get `__class__` into closure cell for super()
# (It will be replaced with the newly constructed class after construction).
"def wrapper(_cls, cached_properties, _cached_setattr_get):",
" __class__ = _cls",
" def __getattr__(self, item):",
" func = cached_properties.get(item)",
" if func is not None:",
" result = func(self)",
" _setter = _cached_setattr_get(self)",
" _setter(item, result)",
" return result",
" if '__attrs_original_getattr__' in vars(__class__):",
" return __class__.__attrs_original_getattr__(self, item)",
" if hasattr(super(), '__getattr__'):",
" return super().__getattr__(item)",
" original_error = f\"'{self.__class__.__name__}' object has no attribute '{item}'\"",
" raise AttributeError(original_error)",
" return __getattr__",
"__getattr__ = wrapper(_cls, cached_properties, _cached_setattr_get)",
]

return __getattr__
unique_filename = _generate_unique_filename(cls, "getattr")

glob = {
"cached_properties": cached_properties,
"_cached_setattr_get": _obj_setattr.__get__,
"_cls": cls,
}

return _make_method(
"__getattr__",
"\n".join(lines),
unique_filename,
glob,
)


def _frozen_setattrs(self, name, value):
Expand Down Expand Up @@ -898,8 +926,12 @@ def _create_slots_class(self):
if annotation is not inspect.Parameter.empty:
cd["__annotations__"][name] = annotation

original_getattr = cd.get("__getattr__")
if original_getattr is not None:
cd["__attrs_original_getattr__"] = original_getattr

cd["__getattr__"] = _make_cached_property_getattr(
cached_properties, cd.get("__getattr__")
cached_properties, self._cls
)

# We only add the names of attributes that aren't inherited.
Expand Down
177 changes: 176 additions & 1 deletion tests/test_slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,23 @@ def f(self):
assert "__dict__" not in dir(A)


@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+")
def test_slots_cached_property_works_on_frozen_isntances():
"""
Infers type of cached property.
"""

@attrs.frozen(slots=True)
class A:
x: int

@functools.cached_property
def f(self) -> int:
return self.x

assert A(x=1).f == 1


@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+")
def test_slots_cached_property_infers_type():
"""
Expand All @@ -768,10 +785,168 @@ def f(self) -> int:
assert A.__annotations__ == {"x": int, "f": int}


@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+")
def test_slots_cached_property_with_empty_getattr_raises_attribute_error_of_requested():
"""
Ensures error information is not lost.
"""

@attr.s(slots=True)
class A:
x = attr.ib()

@functools.cached_property
def f(self):
return self.x

a = A(1)
with pytest.raises(
AttributeError, match="'A' object has no attribute 'z'"
):
a.z


@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+")
def test_slots_cached_property_with_getattr_calls_getattr_for_missing_attributes():
"""
Ensure __getattr__ implementation is maintained for non cached_properties.
"""

@attr.s(slots=True)
class A:
x = attr.ib()

@functools.cached_property
def f(self):
return self.x

def __getattr__(self, item):
return item

a = A(1)
assert a.f == 1
assert a.z == "z"


@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+")
def test_slots_getattr_in_superclass__is_called_for_missing_attributes_when_cached_property_present():
"""
Ensure __getattr__ implementation is maintained in subclass.
"""

@attr.s(slots=True)
class A:
x = attr.ib()

def __getattr__(self, item):
return item

@attr.s(slots=True)
class B(A):
@functools.cached_property
def f(self):
return self.x

b = B(1)
assert b.f == 1
assert b.z == "z"


@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+")
def test_slots_getattr_in_subclass_gets_superclass_cached_property():
"""
Ensure super() in __getattr__ is not broken through cached_property re-write.
"""

@attr.s(slots=True)
class A:
x = attr.ib()

@functools.cached_property
def f(self):
return self.x

def __getattr__(self, item):
return item

@attr.s(slots=True)
class B(A):
@functools.cached_property
def g(self):
return self.x

def __getattr__(self, item):
return super().__getattr__(item)

b = B(1)
assert b.f == 1
assert b.z == "z"


@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+")
def test_slots_sub_class_with_independent_cached_properties_both_work():
"""
Subclassing shouldn't break cached properties.
"""

@attr.s(slots=True)
class A:
x = attr.ib()

@functools.cached_property
def f(self):
return self.x

@attr.s(slots=True)
class B(A):
@functools.cached_property
def g(self):
return self.x * 2

assert B(1).f == 1
assert B(1).g == 2


@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+")
def test_slots_with_multiple_cached_property_subclasses_works():
"""
Multiple sub-classes shouldn't break cached properties.
"""

@attr.s(slots=True)
class A:
x = attr.ib(kw_only=True)

@functools.cached_property
def f(self):
return self.x

@attr.s(slots=False)
class B:
@functools.cached_property
def g(self):
return self.x * 2

def __getattr__(self, item):
if hasattr(super(), "__getattr__"):
return super().__getattr__(item)
return item

@attr.s(slots=True)
class AB(A, B):
pass

ab = AB(x=1)

assert ab.f == 1
assert ab.g == 2
assert ab.h == "h"


@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+")
def test_slots_sub_class_avoids_duplicated_slots():
"""
Duplicating the slots is a wast of memory.
Duplicating the slots is a waste of memory.
"""

@attr.s(slots=True)
Expand Down

0 comments on commit d447b1c

Please sign in to comment.