diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 30b1d852..78f8374d 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -1,9 +1,9 @@ """ NumPy Array API compatibility library -This is a small wrapper around NumPy and CuPy that is compatible with the -Array API standard https://data-apis.org/array-api/latest/. See also NEP 47 -https://numpy.org/neps/nep-0047-array-api-standard.html. +This is a small wrapper around NumPy, CuPy, JAX and others that is compatible +with the Array API standard https://data-apis.org/array-api/latest/. +See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html. Unlike array_api_strict, this is not a strict minimal implementation of the Array API, but rather just an extension of the main NumPy namespace with diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index b011f08d..66542ae8 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -801,6 +801,174 @@ def size(x): return None return math.prod(x.shape) +def is_writeable_array(x): + """ + Return False if x.__setitem__ is expected to raise; True otherwise + """ + if is_numpy_array(x): + return x.flags.writeable + if is_jax_array(x): + return False + return True + +_undef = object() + +def at(x, idx=_undef, /): + """ + Update operations for read-only arrays. + + This implements ``jax.numpy.ndarray.at`` for all backends. + Writeable arrays may be updated in place; you should not rely on it. + + Keyword arguments (e.g. ``indices_are_sorted``) are passed to JAX and are + quietly ignored for backends that don't support them. + + Examples + -------- + Given either of these equivalent expressions:: + + x = at(x)[1].add(2) + x = at(x, 1).add(2) + + If x is a JAX array, they are the same as:: + + x = x.at[1].add(x) + + If x is a read-only numpy array, they are the same as:: + + x = x.copy() + x[1] += 2 + + Otherwise, they are the same as:: + + x[1] += 2 + + Warning + ------- + You should always immediately overwrite the parameter array:: + + x = at(x, 0).set(2) + + The anti-pattern below must be avoided, as it will result in different behaviour + on read-only versus writeable arrays: + + x = xp.asarray([0, 0, 0]) + y = at(x, 0).set(2) + z = at(x, 1).set(3) + + In the above example, y == [2, 0, 0] and z == [0, 3, 0] when x is read-only, + whereas y == z == [2, 3, 0] when x is writeable! + + See Also + -------- + https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html + """ + if is_jax_array(x): + return x.at + if is_numpy_array(x) and not x.flags.writeable: + x = x.copy() + return _InPlaceAt(x, idx) + +class _InPlaceAt: + """Helper of at(). + + Trivially implement jax.numpy.ndarray.at for other backends. + x is updated in place. + """ + __slots__ = ("x", "idx") + + def __init__(self, x, idx=_undef): + self.x = x + self.idx = idx + + def __getitem__(self, idx): + """ + Allow for the alternate syntax ``at(x)[start:stop:step]``, + which looks prettier than ``at(x, slice(start, stop, step))`` + and feels more intuitive coming from the JAX documentation. + """ + self.idx = idx + return self + + def _check_args(self, mode="promise_in_bounds", **kwargs): + if self.idx is _undef: + raise TypeError( + "Index has not been set.\n" + "Usage: either\n" + " at(x, idx).set(value)\n" + "or\n" + " at(x)[idx].set(value)\n" + "(same for all other methods)." + ) + if mode != "promise_in_bounds": + xp = array_namespace(self.x) + raise NotImplementedError( + f"mode='{mode}' is not supported for backend {xp.__name__}" + ) + + def set(self, y, /, **kwargs): + self._check_args(**kwargs) + self.x[self.idx] = y + return self.x + + def add(self, y, /, **kwargs): + self._check_args(**kwargs) + self.x[self.idx] += y + return self.x + + def subtract(self, y, /, **kwargs): + self._check_args(**kwargs) + self.x[self.idx] -= y + return self.x + + def multiply(self, y, /, **kwargs): + self._check_args(**kwargs) + self.x[self.idx] *= y + return self.x + + def divide(self, y, /, **kwargs): + self._check_args(**kwargs) + self.x[self.idx] /= y + return self.x + + def power(self, y, /, **kwargs): + self._check_args(**kwargs) + self.x[self.idx] **= y + return self.x + + def min(self, y, /, **kwargs): + self._check_args(**kwargs) + xp = array_namespace(self.x, y) + self.x[self.idx] = xp.minimum(self.x[self.idx], y) + return self.x + + def max(self, y, /, **kwargs): + self._check_args(**kwargs) + xp = array_namespace(self.x, y) + self.x[self.idx] = xp.maximum(self.x[self.idx], y) + return self.x + + def apply(self, ufunc, /, **kwargs): + self._check_args(**kwargs) + ufunc.at(self.x, self.idx) + return self.x + + def get(self, **kwargs): + self._check_args(**kwargs) + return self.x[self.idx] + +def iwhere(condition, x, y, /): + """Variant of xp.where(condition, x, y) which may or may not update + x in place, if it's possible and beneficial for performance. + """ + xp = array_namespace(condition, x, y) + if is_writeable_array(x): + condition, x, y = xp.broadcast_arrays(condition, x, y) + x[condition] = y[condition] + return x + else: + return xp.where(condition, x, y) + __all__ = [ "array_namespace", "device", @@ -821,8 +989,11 @@ def size(x): "is_ndonnx_namespace", "is_pydata_sparse_array", "is_pydata_sparse_namespace", + "is_writeable_array", "size", "to_device", + "at", + "iwhere", ] _all_ignore = ['sys', 'math', 'inspect', 'warnings'] diff --git a/docs/helper-functions.rst b/docs/helper-functions.rst index f44dc070..dfbd8674 100644 --- a/docs/helper-functions.rst +++ b/docs/helper-functions.rst @@ -36,6 +36,8 @@ instead, which would be wrapped. .. autofunction:: device .. autofunction:: to_device .. autofunction:: size +.. autofunction:: at +.. autofunction:: iwhere Inspection Helpers ------------------ @@ -51,6 +53,7 @@ yet. .. autofunction:: is_jax_array .. autofunction:: is_pydata_sparse_array .. autofunction:: is_ndonnx_array +.. autofunction:: is_writeable_array .. autofunction:: is_numpy_namespace .. autofunction:: is_cupy_namespace .. autofunction:: is_torch_namespace