Skip to content

Commit

Permalink
Abstractions for read-only arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Nov 27, 2024
1 parent ee25aae commit c8f6613
Show file tree
Hide file tree
Showing 5 changed files with 453 additions and 7 deletions.
6 changes: 3 additions & 3 deletions array_api_compat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
NumPy Array API compatibility library
This is a small wrapper around NumPy and CuPy that is compatible with the
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
https://numpy.org/neps/nep-0047-array-api-standard.html.
This is a small wrapper around NumPy, CuPy, JAX and others that is compatible
with the Array API standard https://data-apis.org/array-api/latest/.
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
Unlike array_api_strict, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
Expand Down
266 changes: 263 additions & 3 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
"""
from __future__ import annotations

import operator
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional, Union, Any
from typing import Callable, Literal, Optional, Union, Any
from ._typing import Array, Device

import sys
Expand Down Expand Up @@ -91,7 +92,7 @@ def is_cupy_array(x):
import cupy as cp

# TODO: Should we reject ndarray subclasses?
return isinstance(x, (cp.ndarray, cp.generic))
return isinstance(x, cp.ndarray)

def is_torch_array(x):
"""
Expand Down Expand Up @@ -787,6 +788,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
return x
return x.to_device(device, stream=stream)


def size(x):
"""
Return the total number of elements of x.
Expand All @@ -801,6 +803,262 @@ def size(x):
return None
return math.prod(x.shape)


def is_writeable_array(x) -> bool:
"""
Return False if x.__setitem__ is expected to raise; True otherwise
"""
if is_numpy_array(x):
return x.flags.writeable
if is_jax_array(x) or is_pydata_sparse_array(x):
return False
return True


def _is_fancy_index(idx) -> bool:
if not isinstance(idx, tuple):
idx = (idx,)
return any(
isinstance(i, (list, tuple)) or is_array_api_obj(i)
for i in idx
)


_undef = object()


class at:
"""
Update operations for read-only arrays.
This implements ``jax.numpy.ndarray.at`` for all backends.
Writeable arrays may be updated in place; you should not rely on it.
Keyword arguments (e.g. ``indices_are_sorted``) are passed to JAX and are
quietly ignored for backends that don't support them.
Additionally, this introduces support for the `copy` keyword for all backends:
None
x *may* be modified in place if it is possible and beneficial
for performance. You should not use x after calling this function.
True
Ensure that the inputs are not modified. This is the default.
False
Raise ValueError if a copy cannot be avoided.
Examples
--------
Given either of these equivalent expressions::
x = at(x)[1].add(2, copy=None)
x = at(x, 1).add(2, copy=None)
If x is a JAX array, they are the same as::
x = x.at[1].add(2)
If x is a read-only numpy array, they are the same as::
x = x.copy()
x[1] += 2
Otherwise, they are the same as::
x[1] += 2
Warning
-------
When you use copy=None, you should always immediately overwrite
the parameter array::
x = at(x, 0).set(2, copy=None)
The anti-pattern below must be avoided, as it will result in different behaviour
on read-only versus writeable arrays:
x = xp.asarray([0, 0, 0])
y = at(x, 0).set(2, copy=None)
z = at(x, 1).set(3, copy=None)
In the above example, y == [2, 0, 0] and z == [0, 3, 0] when x is read-only,
whereas y == z == [2, 3, 0] when x is writeable!
Caveat
------
The behaviour of methods other than `get()` when the index is an array of
integers which contains multiple occurrences of the same index is undefined.
**Undefined behaviour:** ``at(x, [0, 0]).set(2)``
See Also
--------
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
"""

__slots__ = ("x", "idx")

def __init__(self, x, idx=_undef):
self.x = x
self.idx = idx

def __getitem__(self, idx):
"""
Allow for the alternate syntax ``at(x)[start:stop:step]``,
which looks prettier than ``at(x, slice(start, stop, step))``
and feels more intuitive coming from the JAX documentation.
"""
if self.idx is not _undef:
raise ValueError("Index has already been set")
self.idx = idx
return self

def _common(
self,
at_op: str,
y=_undef,
copy: bool | None | Literal["_force_false"] = True,
**kwargs,
):
"""Perform common prepocessing.
Returns
-------
If the operation can be resolved by at[], (return value, None)
Otherwise, (None, preprocessed x)
"""
if self.idx is _undef:
raise TypeError(
"Index has not been set.\n"
"Usage: either\n"
" at(x, idx).set(value)\n"
"or\n"
" at(x)[idx].set(value)\n"
"(same for all other methods)."
)

x = self.x

if copy is False:
if not is_writeable_array(x) or is_dask_array(x):
raise ValueError("Cannot modify parameter in place")
elif copy is None:
copy = not is_writeable_array(x)
elif copy == "_force_false":
copy = False
elif copy is not True:
raise ValueError(f"Invalid value for copy: {copy!r}")

if is_jax_array(x):
# Use JAX's at[]
at_ = x.at[self.idx]
args = (y,) if y is not _undef else ()
return getattr(at_, at_op)(*args, **kwargs), None

# Emulate at[] behaviour for non-JAX arrays
if copy:
# FIXME We blindly expect the output of x.copy() to be always writeable.
# This holds true for read-only numpy arrays, but not necessarily for
# other backends.
xp = array_namespace(x)
x = xp.asarray(x, copy=True)

return None, x

def get(self, copy: bool | None = True, **kwargs):
"""
Return x[idx]. In addition to plain __getitem__, this allows ensuring
that the output is either a copy or a view; it also allows passing
kwargs to the backend.
"""
# __getitem__ with a fancy index always returns a copy.
# Avoid an unnecessary double copy.
# If copy is forced to False, raise.
if _is_fancy_index(self.idx):
if copy is False:
raise TypeError(
"Indexing a numpy array with a fancy index always "
"results in a copy"
)
# Skip copy inside _common, even if array is not writeable
copy = "_force_false" # type: ignore

res, x = self._common("get", copy=copy, **kwargs)
if res is not None:
return res
return x[self.idx]

def set(self, y, /, **kwargs):
"""x[idx] = y"""
res, x = self._common("set", y, **kwargs)
if res is not None:
return res
x[self.idx] = y
return x

def apply(self, ufunc, /, **kwargs):
"""ufunc.at(x, idx)"""
if is_cupy_array(self.x) or is_torch_array(self.x) or is_dask_array(self.x):
# ufunc.at not implemented
return self.set(ufunc(self.x[self.idx]), **kwargs)

res, x = self._common("apply", ufunc, **kwargs)
if res is not None:
return res
ufunc.at(x, self.idx)
return x

def _iop(
self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs
):
"""x[idx] += y or equivalent in-place operation on a subset of x
which is the same as saying
x[idx] = x[idx] + y
Note that this is not the same as
operator.iadd(x[idx], y)
Consider for example when x is a numpy array and idx is a fancy index, which
triggers a deep copy on __getitem__.
"""
res, x = self._common(at_op, y, **kwargs)
if res is not None:
return res
x[self.idx] = elwise_op(x[self.idx], y)
return x

def add(self, y, /, **kwargs):
"""x[idx] += y"""
return self._iop("add", operator.add, y, **kwargs)

def subtract(self, y, /, **kwargs):
"""x[idx] -= y"""
return self._iop("subtract", operator.sub, y, **kwargs)

def multiply(self, y, /, **kwargs):
"""x[idx] *= y"""
return self._iop("multiply", operator.mul, y, **kwargs)

def divide(self, y, /, **kwargs):
"""x[idx] /= y"""
return self._iop("divide", operator.truediv, y, **kwargs)

def power(self, y, /, **kwargs):
"""x[idx] **= y"""
return self._iop("power", operator.pow, y, **kwargs)

def min(self, y, /, **kwargs):
"""x[idx] = minimum(x[idx], y)"""
import numpy as np

return self._iop("min", np.minimum, y, **kwargs)

def max(self, y, /, **kwargs):
"""x[idx] = maximum(x[idx], y)"""
import numpy as np

return self._iop("max", np.maximum, y, **kwargs)


__all__ = [
"array_namespace",
"device",
Expand All @@ -821,8 +1079,10 @@ def size(x):
"is_ndonnx_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"is_writeable_array",
"size",
"to_device",
"at",
]

_all_ignore = ['sys', 'math', 'inspect', 'warnings']
_all_ignore = ['inspect', 'math', 'operator', 'warnings', 'sys']
2 changes: 2 additions & 0 deletions docs/helper-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ instead, which would be wrapped.
.. autofunction:: device
.. autofunction:: to_device
.. autofunction:: size
.. autofunction:: at

Inspection Helpers
------------------
Expand All @@ -51,6 +52,7 @@ yet.
.. autofunction:: is_jax_array
.. autofunction:: is_pydata_sparse_array
.. autofunction:: is_ndonnx_array
.. autofunction:: is_writeable_array
.. autofunction:: is_numpy_namespace
.. autofunction:: is_cupy_namespace
.. autofunction:: is_torch_namespace
Expand Down
Loading

0 comments on commit c8f6613

Please sign in to comment.