Skip to content

Commit

Permalink
Abstractions for read-only arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Nov 25, 2024
1 parent ee25aae commit 6884a34
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 3 deletions.
6 changes: 3 additions & 3 deletions array_api_compat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
NumPy Array API compatibility library
This is a small wrapper around NumPy and CuPy that is compatible with the
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
https://numpy.org/neps/nep-0047-array-api-standard.html.
This is a small wrapper around NumPy, CuPy, JAX and others that is compatible
with the Array API standard https://data-apis.org/array-api/latest/.
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
Unlike array_api_strict, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
Expand Down
171 changes: 171 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,174 @@ def size(x):
return None
return math.prod(x.shape)

def is_writeable_array(x):
"""
Return False if x.__setitem__ is expected to raise; True otherwise
"""
if is_numpy_array(x):
return x.flags.writeable
if is_jax_array(x):
return False
return True

_undef = object()

def at(x, idx=_undef, /):
"""
Update operations for read-only arrays.
This implements ``jax.numpy.ndarray.at`` for all backends.
Writeable arrays may be updated in place; you should not rely on it.
Keyword arguments (e.g. ``indices_are_sorted``) are passed to JAX and are
quietly ignored for backends that don't support them.
Examples
--------
Given either of these equivalent expressions::
x = at(x)[1].add(2)
x = at(x, 1).add(2)
If x is a JAX array, they are the same as::
x = x.at[1].add(x)
If x is a read-only numpy array, they are the same as::
x = x.copy()
x[1] += 2
Otherwise, they are the same as::
x[1] += 2
Warning
-------
You should always immediately overwrite the parameter array::
x = at(x, 0).set(2)
The anti-pattern below must be avoided, as it will result in different behaviour
on read-only versus writeable arrays:
x = xp.asarray([0, 0, 0])
y = at(x, 0).set(2)
z = at(x, 1).set(3)
In the above example, y == [2, 0, 0] and z == [0, 3, 0] when x is read-only,
whereas y == z == [2, 3, 0] when x is writeable!
See Also
--------
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
"""
if is_jax_array(x):
return x.at
if is_numpy_array(x) and not x.flags.writeable:
x = x.copy()
return _InPlaceAt(x, idx)

class _InPlaceAt:
"""Helper of at().
Trivially implement jax.numpy.ndarray.at for other backends.
x is updated in place.
"""
__slots__ = ("x", "idx")

def __init__(self, x, idx=_undef):
self.x = x
self.idx = idx

def __getitem__(self, idx):
"""
Allow for the alternate syntax ``at(x)[start:stop:step]``,
which looks prettier than ``at(x, slice(start, stop, step))``
and feels more intuitive coming from the JAX documentation.
"""
self.idx = idx
return self

def _check_args(self, mode="promise_in_bounds", **kwargs):
if self.idx is _undef:
raise TypeError(
"Index has not been set.\n"
"Usage: either\n"
" at(x, idx).set(value)\n"
"or\n"
" at(x)[idx].set(value)\n"
"(same for all other methods)."
)
if mode != "promise_in_bounds":
xp = array_namespace(self.x)
raise NotImplementedError(
f"mode='{mode}' is not supported for backend {xp.__name__}"
)

def set(self, y, /, **kwargs):
self._check_args(**kwargs)
self.x[self.idx] = y
return self.x

def add(self, y, /, **kwargs):
self._check_args(**kwargs)
self.x[self.idx] += y
return self.x

def subtract(self, y, /, **kwargs):
self._check_args(**kwargs)
self.x[self.idx] -= y
return self.x

def multiply(self, y, /, **kwargs):
self._check_args(**kwargs)
self.x[self.idx] *= y
return self.x

def divide(self, y, /, **kwargs):
self._check_args(**kwargs)
self.x[self.idx] /= y
return self.x

def power(self, y, /, **kwargs):
self._check_args(**kwargs)
self.x[self.idx] **= y
return self.x

def min(self, y, /, **kwargs):
self._check_args(**kwargs)
xp = array_namespace(self.x, y)
self.x[self.idx] = xp.minimum(self.x[self.idx], y)
return self.x

def max(self, y, /, **kwargs):
self._check_args(**kwargs)
xp = array_namespace(self.x, y)
self.x[self.idx] = xp.maximum(self.x[self.idx], y)
return self.x

def apply(self, ufunc, /, **kwargs):
self._check_args(**kwargs)
ufunc.at(self.x, self.idx)
return self.x

def get(self, **kwargs):
self._check_args(**kwargs)
return self.x[self.idx]

def iwhere(condition, x, y, /):
"""Variant of xp.where(condition, x, y) which may or may not update
x in place, if it's possible and beneficial for performance.
"""
xp = array_namespace(condition, x, y)
if is_writeable_array(x):
condition, x, y = xp.broadcast_arrays(condition, x, y)
x[condition] = y[condition]
return x
else:
return xp.where(condition, x, y)

__all__ = [
"array_namespace",
"device",
Expand All @@ -821,8 +989,11 @@ def size(x):
"is_ndonnx_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"is_writeable_array",
"size",
"to_device",
"at",
"iwhere",
]

_all_ignore = ['sys', 'math', 'inspect', 'warnings']
3 changes: 3 additions & 0 deletions docs/helper-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ instead, which would be wrapped.
.. autofunction:: device
.. autofunction:: to_device
.. autofunction:: size
.. autofunction:: at
.. autofunction:: iwhere

Inspection Helpers
------------------
Expand All @@ -51,6 +53,7 @@ yet.
.. autofunction:: is_jax_array
.. autofunction:: is_pydata_sparse_array
.. autofunction:: is_ndonnx_array
.. autofunction:: is_writeable_array
.. autofunction:: is_numpy_namespace
.. autofunction:: is_cupy_namespace
.. autofunction:: is_torch_namespace
Expand Down

0 comments on commit 6884a34

Please sign in to comment.