diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 0de639fc..3337b4a2 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -91,7 +91,7 @@ def is_cupy_array(x): import cupy as cp # TODO: Should we reject ndarray subclasses? - return isinstance(x, (cp.ndarray, cp.generic)) + return isinstance(x, cp.ndarray) def is_torch_array(x): """ diff --git a/tests/test_common.py b/tests/test_common.py index 7b8a3a43..7503481e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -55,6 +55,25 @@ def test_is_xp_namespace(library, func): assert is_func(lib) == (func == is_namespace_functions[library]) +@pytest.mark.parametrize('library', all_libraries) +def test_xp_is_array_generics(library): + """ + Test that scalar selection on a xp.ndarray always returns + an object that matches with exactly one among the is_*_array + function of the same library and is_numpy_array. + """ + lib = import_(library) + x = lib.asarray([1, 2, 3]) + x0 = x[0] + + matches = [] + for library2, func in is_array_functions.items(): + is_func = globals()[func] + if is_func(x0): + matches.append(library2) + assert matches in ([library], ["numpy"]) + + @pytest.mark.parametrize("library", all_libraries) def test_is_writeable_array(library): lib = import_(library)