Skip to content

Commit

Permalink
Self-review
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Dec 11, 2024
1 parent ca147d5 commit 18096a5
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 28 deletions.
134 changes: 109 additions & 25 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ def _common(
xp: ModuleType | None = None,
_is_update: bool = True,
**kwargs: Untyped,
) -> tuple[Untyped, None] | tuple[None, Array]:
) -> tuple[Array, None] | tuple[None, Array]:
"""Perform common prepocessing.
Returns
Expand All @@ -704,16 +704,22 @@ def _common(

x = self.x

if copy not in (True, False, None):
msg = f"copy must be True, False, or None; got {copy!r}" # pyright: ignore[reportUnreachable]
raise ValueError(msg)

if copy is None:
writeable = is_writeable_array(x)
copy = _is_update and not writeable
elif copy:
writeable = None
else:
elif _is_update:
writeable = is_writeable_array(x)
if not writeable:
msg = "Cannot modify parameter in place"
raise ValueError(msg)
else:
writeable = None

if copy:
try:
Expand All @@ -723,10 +729,10 @@ def _common(
# with a copy followed by an update
if xp is None:
xp = array_namespace(x)
# Create writeable copy of read-only numpy array
x = xp.asarray(x, copy=True)
if writeable is False:
# A copy of a read-only numpy array is writeable
# Note: this assumes that a copy of a writeable array is writeable
writeable = None
else:
# Use JAX's at[] or other library that with the same duck-type API
Expand All @@ -743,12 +749,18 @@ def _common(

return None, x

def get(self, **kwargs: Untyped) -> Untyped:
def get(
self,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Untyped:
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
that the output is either a copy or a view; it also allows passing
keyword arguments to the backend.
"""
if kwargs.get("copy") is False:
if copy is False:
if is_array_api_obj(self.idx):
# Boolean index. Note that the array API spec
# https://data-apis.org/array-api/latest/API_specification/indexing.html
Expand All @@ -758,19 +770,38 @@ def get(self, **kwargs: Untyped) -> Untyped:
# which can be caught by testing the user code vs. array-api-strict.
msg = "get() with an array index always returns a copy"
raise ValueError(msg)

# Prevent scalar indices together with copy=False.
# Even if some backends may return a scalar view of the original, we chose to be
# strict here beceause some other backends, such as numpy, definitely don't.
tup_idx = self.idx if isinstance(self.idx, tuple) else (self.idx,)
if any(
i is not None and i is not Ellipsis and not isinstance(i, slice)
for i in tup_idx
):
msg = "get() with a scalar index typically returns a copy"
raise ValueError(msg)

if is_dask_array(self.x):
msg = "get() on Dask arrays always returns a copy"
raise ValueError(msg)

res, x = self._common("get", _is_update=False, **kwargs)
res, x = self._common("get", copy=copy, xp=xp, _is_update=False, **kwargs)
if res is not None:
return res
assert x is not None
return x[self.idx]

def set(self, y: Array, /, **kwargs: Untyped) -> Array:
def set(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] = y`` and return the update array"""
res, x = self._common("set", y, **kwargs)
res, x = self._common("set", y, copy=copy, xp=xp, **kwargs)
if res is not None:
return res
assert x is not None
Expand All @@ -785,6 +816,8 @@ def _iop(
elwise_op: Callable[[Array, Array], Array],
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""x[idx] += y or equivalent in-place operation on a subset of x
Expand All @@ -796,41 +829,92 @@ def _iop(
Consider for example when x is a numpy array and idx is a fancy index, which
triggers a deep copy on __getitem__.
"""
res, x = self._common(at_op, y, **kwargs)
res, x = self._common(at_op, y, copy=copy, xp=xp, **kwargs)
if res is not None:
return res
assert x is not None
x[self.idx] = elwise_op(x[self.idx], y)
return x

def add(self, y: Array, /, **kwargs: Untyped) -> Array:
def add(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] += y`` and return the updated array"""
return self._iop("add", operator.add, y, **kwargs)
return self._iop("add", operator.add, y, copy=copy, xp=xp, **kwargs)

def subtract(self, y: Array, /, **kwargs: Untyped) -> Array:
def subtract(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] -= y`` and return the updated array"""
return self._iop("subtract", operator.sub, y, **kwargs)
return self._iop("subtract", operator.sub, y, copy=copy, xp=xp, **kwargs)

def multiply(self, y: Array, /, **kwargs: Untyped) -> Array:
def multiply(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] *= y`` and return the updated array"""
return self._iop("multiply", operator.mul, y, **kwargs)
return self._iop("multiply", operator.mul, y, copy=copy, xp=xp, **kwargs)

def divide(self, y: Array, /, **kwargs: Untyped) -> Array:
def divide(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] /= y`` and return the updated array"""
return self._iop("divide", operator.truediv, y, **kwargs)
return self._iop("divide", operator.truediv, y, copy=copy, xp=xp, **kwargs)

def power(self, y: Array, /, **kwargs: Untyped) -> Array:
def power(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] **= y`` and return the updated array"""
return self._iop("power", operator.pow, y, **kwargs)
return self._iop("power", operator.pow, y, copy=copy, xp=xp, **kwargs)

def min(self, y: Array, /, **kwargs: Untyped) -> Array:
def min(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
xp = array_namespace(self.x)
if xp is None:
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("min", xp.minimum, y, **kwargs)
return self._iop("min", xp.minimum, y, copy=copy, xp=xp, **kwargs)

def max(self, y: Array, /, **kwargs: Untyped) -> Array:
def max(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
xp = array_namespace(self.x)
if xp is None:
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("max", xp.maximum, y, **kwargs)
return self._iop("max", xp.maximum, y, copy=copy, xp=xp, **kwargs)
34 changes: 31 additions & 3 deletions tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
array_namespace,
is_dask_array,
is_numpy_array,
is_pydata_sparse_array,
is_writeable_array,
)
Expand Down Expand Up @@ -110,6 +111,14 @@ def test_get(array: Array, copy: bool | None):
return
expect_copy = True

# get(copy=False) on a read-only numpy array returns a read-only view
if is_numpy_array(array) and not copy and not array.flags.writeable:
out = at(array, slice(2)).get(copy=copy)
assert_array_equal(out, [10.0, 20.0])
assert out.base is array
assert not out.flags.writeable
return

with assert_copy(array, expect_copy):
y = at(array, slice(2)).get(copy=copy)
assert isinstance(y, type(array))
Expand All @@ -119,6 +128,18 @@ def test_get(array: Array, copy: bool | None):
y[:] = 40


def test_get_scalar_nocopy(array: Array):
"""get(copy=False) with a scalar index always raises, because some backends
such as numpy and sparse return a np.generic instead of a scalar view
"""
with pytest.raises(ValueError, match="scalar"):
at(array)[0].get(copy=False)
with pytest.raises(ValueError, match="scalar"):
at(array)[(0, )].get(copy=False)
with pytest.raises(ValueError, match="scalar"):
at(array)[..., 0].get(copy=False)


def test_get_bool_indices(array: Array):
"""get() with a boolean array index always returns a copy"""
# sparse violates the array API as it doesn't support
Expand Down Expand Up @@ -146,10 +167,17 @@ def test_get_bool_indices(array: Array):
def test_copy_invalid():
a = np.asarray([1, 2, 3])
with pytest.raises(ValueError, match="copy"):
at(a, 0).set(4, copy="invalid")
at(a, 0).set(4, copy="invalid") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]


def test_xp():
a = np.asarray([1, 2, 3])
b = at(a, 0).set(4, xp=np)
assert_array_equal(b, [4, 2, 3])
at(a, 0).get(xp=np)
at(a, 0).set(4, xp=np)
at(a, 0).add(4, xp=np)
at(a, 0).subtract(4, xp=np)
at(a, 0).multiply(4, xp=np)
at(a, 0).divide(4, xp=np)
at(a, 0).power(4, xp=np)
at(a, 0).min(4, xp=np)
at(a, 0).max(4, xp=np)

0 comments on commit 18096a5

Please sign in to comment.