Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Test for read-only arrays #205

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions array_api_compat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
NumPy Array API compatibility library

This is a small wrapper around NumPy and CuPy that is compatible with the
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
https://numpy.org/neps/nep-0047-array-api-standard.html.
This is a small wrapper around NumPy, CuPy, JAX, sparse and others that are
compatible with the Array API standard https://data-apis.org/array-api/latest/.
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(unrelated to this PR, but probably no need to link to that NEP now)


Unlike array_api_strict, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
Expand Down
19 changes: 19 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
return x
return x.to_device(device, stream=stream)


def size(x):
"""
Return the total number of elements of x.
Expand All @@ -801,6 +802,23 @@ def size(x):
return None
return math.prod(x.shape)


def is_writeable_array(x) -> bool:
"""
Return False if ``x.__setitem__`` is expected to raise; True otherwise.

Warning
-------
As there is no standard way to check if an array is writeable without actually
writing to it, this function blindly returns True for all unknown array types.
"""
if is_numpy_array(x):
return x.flags.writeable
if is_jax_array(x) or is_pydata_sparse_array(x):
return False
return True


__all__ = [
"array_namespace",
"device",
Expand All @@ -821,6 +839,7 @@ def size(x):
"is_ndonnx_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"is_writeable_array",
"size",
"to_device",
]
Expand Down
1 change: 1 addition & 0 deletions docs/helper-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ yet.
.. autofunction:: is_jax_array
.. autofunction:: is_pydata_sparse_array
.. autofunction:: is_ndonnx_array
.. autofunction:: is_writeable_array
.. autofunction:: is_numpy_namespace
.. autofunction:: is_cupy_namespace
.. autofunction:: is_torch_namespace
Expand Down
20 changes: 19 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
)

from array_api_compat import is_array_api_obj, device, to_device
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device

from ._helpers import import_, wrapped_libraries, all_libraries

Expand Down Expand Up @@ -74,6 +74,24 @@ def test_xp_is_array_generics(library):
assert matches in ([library], ["numpy"])


@pytest.mark.parametrize("library", all_libraries)
def test_is_writeable_array(library):
lib = import_(library)
x = lib.asarray([1, 2, 3])
if is_writeable_array(x):
x[1] = 4
else:
with pytest.raises((TypeError, ValueError)):
x[1] = 4


def test_is_writeable_array_numpy():
x = np.asarray([1, 2, 3])
assert is_writeable_array(x)
x.flags.writeable = False
assert not is_writeable_array(x)


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)
Expand Down
Loading