From d8a4fc2b9c1de0cd6e32672c5a862d0c47b7eb4d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 16 Dec 2024 16:03:19 +0000 Subject: [PATCH] New function is_writeable_array --- array_api_compat/__init__.py | 6 +++--- array_api_compat/common/_helpers.py | 19 +++++++++++++++++++ docs/helper-functions.rst | 1 + tests/test_common.py | 20 +++++++++++++++++++- 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 30b1d852..6911bdd2 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -1,9 +1,9 @@ """ NumPy Array API compatibility library -This is a small wrapper around NumPy and CuPy that is compatible with the -Array API standard https://data-apis.org/array-api/latest/. See also NEP 47 -https://numpy.org/neps/nep-0047-array-api-standard.html. +This is a small wrapper around NumPy, CuPy, JAX, sparse and others that are +compatible with the Array API standard https://data-apis.org/array-api/latest/. +See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html. Unlike array_api_strict, this is not a strict minimal implementation of the Array API, but rather just an extension of the main NumPy namespace with diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 23a755cd..706821c4 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -787,6 +787,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] return x return x.to_device(device, stream=stream) + def size(x): """ Return the total number of elements of x. @@ -801,6 +802,23 @@ def size(x): return None return math.prod(x.shape) + +def is_writeable_array(x) -> bool: + """ + Return False if ``x.__setitem__`` is expected to raise; True otherwise. + + Warning + ------- + As there is no standard way to check if an array is writeable without actually + writing to it, this function blindly returns True for all unknown array types. + """ + if is_numpy_array(x): + return x.flags.writeable + if is_jax_array(x) or is_pydata_sparse_array(x): + return False + return True + + __all__ = [ "array_namespace", "device", @@ -821,6 +839,7 @@ def size(x): "is_ndonnx_namespace", "is_pydata_sparse_array", "is_pydata_sparse_namespace", + "is_writeable_array", "size", "to_device", ] diff --git a/docs/helper-functions.rst b/docs/helper-functions.rst index f44dc070..9d620ceb 100644 --- a/docs/helper-functions.rst +++ b/docs/helper-functions.rst @@ -51,6 +51,7 @@ yet. .. autofunction:: is_jax_array .. autofunction:: is_pydata_sparse_array .. autofunction:: is_ndonnx_array +.. autofunction:: is_writeable_array .. autofunction:: is_numpy_namespace .. autofunction:: is_cupy_namespace .. autofunction:: is_torch_namespace diff --git a/tests/test_common.py b/tests/test_common.py index d86b9f86..7503481e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,7 +5,7 @@ is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, ) -from array_api_compat import is_array_api_obj, device, to_device +from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device from ._helpers import import_, wrapped_libraries, all_libraries @@ -74,6 +74,24 @@ def test_xp_is_array_generics(library): assert matches in ([library], ["numpy"]) +@pytest.mark.parametrize("library", all_libraries) +def test_is_writeable_array(library): + lib = import_(library) + x = lib.asarray([1, 2, 3]) + if is_writeable_array(x): + x[1] = 4 + else: + with pytest.raises((TypeError, ValueError)): + x[1] = 4 + + +def test_is_writeable_array_numpy(): + x = np.asarray([1, 2, 3]) + assert is_writeable_array(x) + x.flags.writeable = False + assert not is_writeable_array(x) + + @pytest.mark.parametrize("library", all_libraries) def test_device(library): xp = import_(library, wrapper=True)