From 564cb39ddea835921dc93d25fb4a2f17a1a7864a Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 12 Dec 2024 15:50:15 +0000 Subject: [PATCH] New function is_writeable_array --- array_api_compat/__init__.py | 6 +++--- array_api_compat/common/_helpers.py | 14 ++++++++++++++ docs/helper-functions.rst | 1 + tests/test_common.py | 20 +++++++++++++++++++- 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 30b1d852..754252e1 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -1,9 +1,9 @@ """ NumPy Array API compatibility library -This is a small wrapper around NumPy and CuPy that is compatible with the -Array API standard https://data-apis.org/array-api/latest/. See also NEP 47 -https://numpy.org/neps/nep-0047-array-api-standard.html. +This is a small wrapper around NumPy, CuPy, JAX, sparse and others that is +compatible with the Array API standard https://data-apis.org/array-api/latest/. +See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html. Unlike array_api_strict, this is not a strict minimal implementation of the Array API, but rather just an extension of the main NumPy namespace with diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index b011f08d..0de639fc 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -787,6 +787,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] return x return x.to_device(device, stream=stream) + def size(x): """ Return the total number of elements of x. @@ -801,6 +802,18 @@ def size(x): return None return math.prod(x.shape) + +def is_writeable_array(x) -> bool: + """ + Return False if ``x.__setitem__`` is expected to raise; True otherwise + """ + if is_numpy_array(x): + return x.flags.writeable + if is_jax_array(x) or is_pydata_sparse_array(x): + return False + return True + + __all__ = [ "array_namespace", "device", @@ -821,6 +834,7 @@ def size(x): "is_ndonnx_namespace", "is_pydata_sparse_array", "is_pydata_sparse_namespace", + "is_writeable_array", "size", "to_device", ] diff --git a/docs/helper-functions.rst b/docs/helper-functions.rst index f44dc070..9d620ceb 100644 --- a/docs/helper-functions.rst +++ b/docs/helper-functions.rst @@ -51,6 +51,7 @@ yet. .. autofunction:: is_jax_array .. autofunction:: is_pydata_sparse_array .. autofunction:: is_ndonnx_array +.. autofunction:: is_writeable_array .. autofunction:: is_numpy_namespace .. autofunction:: is_cupy_namespace .. autofunction:: is_torch_namespace diff --git a/tests/test_common.py b/tests/test_common.py index 6bf55618..7b8a3a43 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,7 +5,7 @@ is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, ) -from array_api_compat import is_array_api_obj, device, to_device +from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device from ._helpers import import_, wrapped_libraries, all_libraries @@ -55,6 +55,24 @@ def test_is_xp_namespace(library, func): assert is_func(lib) == (func == is_namespace_functions[library]) +@pytest.mark.parametrize("library", all_libraries) +def test_is_writeable_array(library): + lib = import_(library) + x = lib.asarray([1, 2, 3]) + if is_writeable_array(x): + x[1] = 4 + else: + with pytest.raises((TypeError, ValueError)): + x[1] = 4 + + +def test_is_writeable_array_numpy(): + x = np.asarray([1, 2, 3]) + assert is_writeable_array(x) + x.flags.writeable = False + assert not is_writeable_array(x) + + @pytest.mark.parametrize("library", all_libraries) def test_device(library): xp = import_(library, wrapper=True)