Skip to content

Commit

Permalink
Use xfail instead of silently passing tests
Browse files Browse the repository at this point in the history
This makes it more clear when certain features are not supported by an
algebra.

To be able to use xfail, this adds a mechanism to declare what sidedness
values are supported by specfic algebra operations. This mechanism is
a decorator which adds an attribute with the allowed sidedness values to
the method object.

Closes #281.
  • Loading branch information
jgosmann committed Nov 6, 2021
1 parent 9a65404 commit 4a370e5
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 230 deletions.
1 change: 1 addition & 0 deletions docs/modules/nengo_spa.algebras.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The following items are re-exported by :mod:`nengo_spa.algebras`:
base.AbstractAlgebra
base.CommonProperties
base.ElementSidedness
base.supports_sidedness
hrr_algebra.HrrAlgebra
hrr_algebra.HrrProperties
vtb_algebra.VtbAlgebra
Expand Down
44 changes: 44 additions & 0 deletions nengo_spa/algebras/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,50 @@ class ElementSidedness(Enum):
TWO_SIDED = "two-sided"


def supports_sidedness(sidedness):
"""Declare supported sidedness values on an operation.
This decorator can be used with methods in an algebra that take a
*sidedness* parameter. It declares which values of *sidedness* are
supported by the algebra. The valid values are added as a *frozenset* as
a *supported_sidedness* attribute on the method.
When checking for supported sidedness, it must first be checked whether
the *supported_sidedness* attribute exists (for backwards compatibility).
If it does not exist, it should be assumed that all values for *sidedness*
are supported.
Parameters
----------
sidedness: Iterable[ElementSidedness]
The sidedness values that are supported by the annotated method.
Returns
-------
function
The method itself with the *supported_sidedness* attribute added.
Examples
-------
>>> class MyAlgebra(AbstractAlgebra):
>>> @supports_sidedness({ElementSidedness.LEFT})
>>> def invert(self, v, sidedness):
>>> # ...
>>>
>>> # ...
>>>
>>> print(MyAlgebra.invert.supported_sidedness)
frozenset({<ElementSidedness.LEFT: 'left'>})
"""

def decorator(fn):
setattr(fn, "supported_sidedness", frozenset(sidedness))
return fn

return decorator


class _DuckTypedABCMeta(ABCMeta):
def __instancecheck__(cls, instance):
if super().__instancecheck__(instance):
Expand Down
274 changes: 135 additions & 139 deletions nengo_spa/algebras/tests/test_algebras.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from nengo_spa.algebras.base import AbstractAlgebra, CommonProperties, ElementSidedness
from nengo_spa.algebras.vtb_algebra import VtbAlgebra
from nengo_spa.conftest import check_sidedness
from nengo_spa.vector_generation import UnitLengthVectors


Expand Down Expand Up @@ -35,39 +36,39 @@ def test_superpose(algebra, rng):
@pytest.mark.parametrize("d", [25, 36])
@pytest.mark.parametrize("sidedness", ElementSidedness)
def test_binding_and_invert(algebra, d, sidedness, rng):
check_sidedness(algebra, "invert", sidedness)

dissimilarity_passed = 0
unbinding_passed = 0
try:
for i in range(10):
gen = UnitLengthVectors(d, rng=rng)
a = next(gen)
b = next(gen)

binding_side = sidedness
if sidedness is ElementSidedness.TWO_SIDED:
binding_side = (
ElementSidedness.LEFT if i % 1 == 0 else ElementSidedness.RIGHT
)

with warnings.catch_warnings():
warnings.simplefilter("error", DeprecationWarning)
if binding_side is ElementSidedness.LEFT:
bound = algebra.bind(b, a)
r = algebra.bind(algebra.invert(b, sidedness=sidedness), bound)
elif binding_side is ElementSidedness.RIGHT:
bound = algebra.bind(a, b)
r = algebra.bind(bound, algebra.invert(b, sidedness=sidedness))
else:
raise AssertionError("Invalid binding_side value.")

for v in (a, b):
dissimilarity_passed += np.dot(v, bound / np.linalg.norm(b)) < 0.7
unbinding_passed += np.dot(a, r / np.linalg.norm(r)) > 0.6

assert dissimilarity_passed >= 2 * 8
assert unbinding_passed >= 7
except (NotImplementedError, DeprecationWarning):
pass

for i in range(10):
gen = UnitLengthVectors(d, rng=rng)
a = next(gen)
b = next(gen)

binding_side = sidedness
if sidedness is ElementSidedness.TWO_SIDED:
binding_side = (
ElementSidedness.LEFT if i % 1 == 0 else ElementSidedness.RIGHT
)

with warnings.catch_warnings():
warnings.simplefilter("error", DeprecationWarning)
if binding_side is ElementSidedness.LEFT:
bound = algebra.bind(b, a)
r = algebra.bind(algebra.invert(b, sidedness=sidedness), bound)
elif binding_side is ElementSidedness.RIGHT:
bound = algebra.bind(a, b)
r = algebra.bind(bound, algebra.invert(b, sidedness=sidedness))
else:
raise AssertionError("Invalid binding_side value.")

for v in (a, b):
dissimilarity_passed += np.dot(v, bound / np.linalg.norm(b)) < 0.7
unbinding_passed += np.dot(a, r / np.linalg.norm(r)) > 0.6

assert dissimilarity_passed >= 2 * 8
assert unbinding_passed >= 7


@pytest.mark.parametrize("d", [25, 36])
Expand All @@ -86,19 +87,18 @@ def test_integer_binding_power(algebra, d, rng):

@pytest.mark.parametrize("d", [25, 36])
def test_integer_binding_is_consistent_with_base_implementation(algebra, d, rng):
check_sidedness(algebra, "identity_element", ElementSidedness.LEFT)

v = algebra.create_vector(d, set(), rng=rng)

try:
for exponent in range(-2, 4):
assert np.allclose(
algebra.binding_power(v, exponent),
AbstractAlgebra.binding_power(algebra, v, exponent),
)
for exponent in range(-2, 4):
assert np.allclose(
algebra.binding_power(v, exponent),
AbstractAlgebra.binding_power(algebra, v, exponent),
)

with pytest.raises(ValueError, match="only supports integer binding powers"):
AbstractAlgebra.binding_power(algebra, v, 0.5)
except NotImplementedError:
pytest.skip()
with pytest.raises(ValueError, match="only supports integer binding powers"):
AbstractAlgebra.binding_power(algebra, v, 0.5)


@pytest.mark.parametrize("d", [16, 25])
Expand Down Expand Up @@ -139,110 +139,109 @@ def test_get_binding_matrix(algebra, rng):
@pytest.mark.filterwarnings("ignore:.*sidedness:DeprecationWarning")
@pytest.mark.parametrize("sidedness", ElementSidedness)
def test_get_inversion_matrix(algebra, sidedness, rng):
check_sidedness(algebra, "invert", sidedness)
a = next(UnitLengthVectors(16, rng=rng))
try:
m = algebra.get_inversion_matrix(16, sidedness=sidedness)
assert np.allclose(algebra.invert(a, sidedness=sidedness), np.dot(m, a))
except NotImplementedError:
pass
m = algebra.get_inversion_matrix(16, sidedness=sidedness)
assert np.allclose(algebra.invert(a, sidedness=sidedness), np.dot(m, a))


@pytest.mark.parametrize("sidedness", ElementSidedness)
def test_absorbing_element(algebra, sidedness, rng):
check_sidedness(algebra, "absorbing_element", sidedness)

a = next(UnitLengthVectors(16, rng=rng))
try:
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", DeprecationWarning)
p = algebra.absorbing_element(16)
except NotImplementedError:
pass
else:
is_deprecated = len(caught_warnings) > 0 and any(
issubclass(w.category, DeprecationWarning) for w in caught_warnings
)
if (
sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED)
and not is_deprecated
):
r = algebra.bind(p, a)
r /= np.linalg.norm(r)
assert np.allclose(p, r) or np.allclose(p, -r)
if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED):
r = algebra.bind(a, p)
r /= np.linalg.norm(r)
assert np.allclose(p, r) or np.allclose(p, -r)

with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", DeprecationWarning)
p = algebra.absorbing_element(16)

is_deprecated = len(caught_warnings) > 0 and any(
issubclass(w.category, DeprecationWarning) for w in caught_warnings
)
if (
sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED)
and not is_deprecated
):
r = algebra.bind(p, a)
r /= np.linalg.norm(r)
assert np.allclose(p, r) or np.allclose(p, -r)
if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED):
r = algebra.bind(a, p)
r /= np.linalg.norm(r)
assert np.allclose(p, r) or np.allclose(p, -r)


@pytest.mark.parametrize("sidedness", ElementSidedness)
def test_identity_element(algebra, sidedness, rng):
check_sidedness(algebra, "identity_element", sidedness)

a = next(UnitLengthVectors(16, rng=rng))
try:
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", DeprecationWarning)
p = algebra.identity_element(16)
except NotImplementedError:
pass
else:
is_deprecated = len(caught_warnings) > 0 and any(
issubclass(w.category, DeprecationWarning) for w in caught_warnings
)
if (
sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED)
and not is_deprecated
):
assert np.allclose(algebra.bind(p, a), a)
if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED):
assert np.allclose(algebra.bind(a, p), a)

with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", DeprecationWarning)
p = algebra.identity_element(16)

is_deprecated = len(caught_warnings) > 0 and any(
issubclass(w.category, DeprecationWarning) for w in caught_warnings
)
if (
sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED)
and not is_deprecated
):
assert np.allclose(algebra.bind(p, a), a)
if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED):
assert np.allclose(algebra.bind(a, p), a)


@pytest.mark.parametrize("sidedness", ElementSidedness)
def test_negative_identity_element(algebra, sidedness, rng):
try:
x = next(UnitLengthVectors(16, rng=rng))
a = algebra.abs(x)
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", DeprecationWarning)
p = algebra.negative_identity_element(16)
except NotImplementedError:
pass
else:
is_deprecated = len(caught_warnings) > 0 and any(
issubclass(w.category, DeprecationWarning) for w in caught_warnings
)
if (
sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED)
and not is_deprecated
):
b = algebra.bind(p, a)
assert np.allclose(algebra.abs(b), a)
assert algebra.sign(b).is_negative()
if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED):
b = algebra.bind(a, p)
assert np.allclose(algebra.abs(b), a)
assert algebra.sign(b).is_negative()
x = next(UnitLengthVectors(16, rng=rng))
if algebra.sign(x).is_indefinite():
pytest.xfail("Generated vector has an indefinite sign.")

a = algebra.abs(x)

with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", DeprecationWarning)
p = algebra.negative_identity_element(16)

is_deprecated = len(caught_warnings) > 0 and any(
issubclass(w.category, DeprecationWarning) for w in caught_warnings
)
if (
sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED)
and not is_deprecated
):
b = algebra.bind(p, a)
assert np.allclose(algebra.abs(b), a)
assert algebra.sign(b).is_negative()
if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED):
b = algebra.bind(a, p)
assert np.allclose(algebra.abs(b), a)
assert algebra.sign(b).is_negative()


@pytest.mark.parametrize("sidedness", ElementSidedness)
def test_zero_element(algebra, sidedness, rng):
check_sidedness(algebra, "zero_element", sidedness)

a = next(UnitLengthVectors(16, rng=rng))
try:
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", DeprecationWarning)
p = algebra.zero_element(16)
except NotImplementedError:
pass
else:
assert np.all(p == 0.0)
is_deprecated = len(caught_warnings) > 0 and any(
issubclass(w.category, DeprecationWarning) for w in caught_warnings
)
if (
sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED)
and not is_deprecated
):
assert np.allclose(algebra.bind(a, p), 0.0)
if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED):
assert np.allclose(algebra.bind(p, a), 0.0)

with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", DeprecationWarning)
p = algebra.zero_element(16)

assert np.all(p == 0.0)
is_deprecated = len(caught_warnings) > 0 and any(
issubclass(w.category, DeprecationWarning) for w in caught_warnings
)
if (
sidedness in (ElementSidedness.LEFT, ElementSidedness.TWO_SIDED)
and not is_deprecated
):
assert np.allclose(algebra.bind(a, p), 0.0)
if sidedness in (ElementSidedness.RIGHT, ElementSidedness.TWO_SIDED):
assert np.allclose(algebra.bind(p, a), 0.0)


def test_isinstance_check(algebra):
Expand Down Expand Up @@ -318,22 +317,19 @@ def test_isinstance_ducktyping_check():
@pytest.mark.parametrize("sidedness", ElementSidedness)
@pytest.mark.filterwarnings("ignore:.*sidedness:DeprecationWarning")
def test_sign(algebra, element, check_property, sidedness):
try:
v = getattr(algebra, element + "_element")(16, sidedness)
assert getattr(algebra.sign(v), check_property)()
except NotImplementedError:
pass
method_name = f"{element}_element"
check_sidedness(algebra, method_name, sidedness)
v = getattr(algebra, method_name)(16, sidedness)
assert getattr(algebra.sign(v), check_property)()


@pytest.mark.parametrize("d", [16, 25])
@pytest.mark.parametrize("sidedness", ElementSidedness)
def test_abs(algebra, d, sidedness):
try:
neg_v = algebra.negative_identity_element(d, sidedness)
assert algebra.sign(neg_v).is_negative()
v = algebra.abs(neg_v)
assert algebra.sign(v).is_positive()
assert np.allclose(v, algebra.identity_element(d, sidedness))
assert np.allclose(algebra.abs(v), v) # idempotency
except NotImplementedError:
pass
check_sidedness(algebra, "negative_identity_element", sidedness)
neg_v = algebra.negative_identity_element(d, sidedness)
assert algebra.sign(neg_v).is_negative()
v = algebra.abs(neg_v)
assert algebra.sign(v).is_positive()
assert np.allclose(v, algebra.identity_element(d, sidedness))
assert np.allclose(algebra.abs(v), v) # idempotency
Loading

0 comments on commit 4a370e5

Please sign in to comment.