Skip to content

Commit

Permalink
New function is_writeable_array
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Dec 12, 2024
1 parent d9df003 commit 564cb39
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 4 deletions.
6 changes: 3 additions & 3 deletions array_api_compat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
NumPy Array API compatibility library
This is a small wrapper around NumPy and CuPy that is compatible with the
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
https://numpy.org/neps/nep-0047-array-api-standard.html.
This is a small wrapper around NumPy, CuPy, JAX, sparse and others that is
compatible with the Array API standard https://data-apis.org/array-api/latest/.
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
Unlike array_api_strict, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
Expand Down
14 changes: 14 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
return x
return x.to_device(device, stream=stream)


def size(x):
"""
Return the total number of elements of x.
Expand All @@ -801,6 +802,18 @@ def size(x):
return None
return math.prod(x.shape)


def is_writeable_array(x) -> bool:
"""
Return False if ``x.__setitem__`` is expected to raise; True otherwise
"""
if is_numpy_array(x):
return x.flags.writeable
if is_jax_array(x) or is_pydata_sparse_array(x):
return False
return True


__all__ = [
"array_namespace",
"device",
Expand All @@ -821,6 +834,7 @@ def size(x):
"is_ndonnx_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"is_writeable_array",
"size",
"to_device",
]
Expand Down
1 change: 1 addition & 0 deletions docs/helper-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ yet.
.. autofunction:: is_jax_array
.. autofunction:: is_pydata_sparse_array
.. autofunction:: is_ndonnx_array
.. autofunction:: is_writeable_array
.. autofunction:: is_numpy_namespace
.. autofunction:: is_cupy_namespace
.. autofunction:: is_torch_namespace
Expand Down
20 changes: 19 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
)

from array_api_compat import is_array_api_obj, device, to_device
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device

from ._helpers import import_, wrapped_libraries, all_libraries

Expand Down Expand Up @@ -55,6 +55,24 @@ def test_is_xp_namespace(library, func):
assert is_func(lib) == (func == is_namespace_functions[library])


@pytest.mark.parametrize("library", all_libraries)
def test_is_writeable_array(library):
lib = import_(library)
x = lib.asarray([1, 2, 3])
if is_writeable_array(x):
x[1] = 4
else:
with pytest.raises((TypeError, ValueError)):
x[1] = 4


def test_is_writeable_array_numpy():
x = np.asarray([1, 2, 3])
assert is_writeable_array(x)
x.flags.writeable = False
assert not is_writeable_array(x)


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)
Expand Down

0 comments on commit 564cb39

Please sign in to comment.