Skip to content

Commit

Permalink
WIP at() method
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Dec 9, 2024
1 parent 00b2ab7 commit 162264b
Show file tree
Hide file tree
Showing 7 changed files with 477 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
:nosignatures:
:toctree: generated
at
atleast_nd
cov
create_diagonal
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ ignore = [
"PLR09", # Too many <...>
"PLR2004", # Magic value used in comparison
"ISC001", # Conflicts with formatter
"EM101", # raw-string-in-exception
"EM102", # f-string-in-exception
"PD008", # pandas-use-of-dot-at
]
isort.required-imports = ["from __future__ import annotations"]

Expand Down
12 changes: 11 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
from ._funcs import (
at,
atleast_nd,
cov,
create_diagonal,
expand_dims,
kron,
setdiff1d,
sinc,
)

__version__ = "0.3.3"

# pylint: disable=duplicate-code
__all__ = [
"__version__",
"at",
"atleast_nd",
"cov",
"create_diagonal",
Expand Down
284 changes: 281 additions & 3 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

import typing
import operator
import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal

if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from ._lib._typing import Array, ModuleType

from ._lib import _utils
from ._lib._compat import array_namespace
from ._lib._compat import (
array_namespace,
is_array_api_obj,
is_writeable_array,
)

__all__ = [
"at",
"atleast_nd",
"cov",
"create_diagonal",
Expand Down Expand Up @@ -546,3 +552,275 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device)
)
return xp.sin(y) / y


def _is_fancy_index(idx: object | tuple[object, ...]) -> bool:
if not isinstance(idx, tuple):
idx = (idx,)
return any(isinstance(i, (list, tuple)) or is_array_api_obj(i) for i in idx)


_undef = object()


class at:
"""
Update operations for read-only arrays.
This implements ``jax.numpy.ndarray.at`` for all backends.
Parameters
----------
x : array
Input array.
copy : bool, optional
True (default)
Ensure that the inputs are not modified.
False
Ensure that the update operation writes back to the input.
Raise ValueError if a copy cannot be avoided.
None
The array parameter *may* be modified in place if it is possible and
beneficial for performance.
You should not reuse it after calling this function.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer
Additionally, if the backend supports an `at` method, any additional keyword
arguments are passed to it verbatim; e.g. this allows passing
``indices_are_sorted=True`` to JAX.
Returns
-------
Updated input array.
Examples
--------
Given either of these equivalent expressions::
x = at(x)[1].add(2, copy=None)
x = at(x, 1).add(2, copy=None)
If x is a JAX array, they are the same as::
x = x.at[1].add(2)
If x is a read-only numpy array, they are the same as::
x = x.copy()
x[1] += 2
Otherwise, they are the same as::
x[1] += 2
Warning
-------
When you use copy=None, you should always immediately overwrite
the parameter array::
x = at(x, 0).set(2, copy=None)
The anti-pattern below must be avoided, as it will result in different behaviour
on read-only versus writeable arrays::
x = xp.asarray([0, 0, 0])
y = at(x, 0).set(2, copy=None)
z = at(x, 1).set(3, copy=None)
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
Warning
-------
The behaviour of update methods when the index is an array of integers which
contains multiple occurrences of the same index is undefined;
e.g. ``at(x, [0, 0]).set(2)``
Note
----
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
See Also
--------
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
"""

x: Array
idx: Any
__slots__ = ("x", "idx")

def __init__(self, x: Array, idx: Any = _undef, /):
self.x = x
self.idx = idx

def __getitem__(self, idx: Any) -> Any:
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
which looks prettier than ``at(x, slice(start, stop, step))``
and feels more intuitive coming from the JAX documentation.
"""
if self.idx is not _undef:
raise ValueError("Index has already been set")
self.idx = idx
return self

def _common(
self,
at_op: str,
y: Array = _undef,
/,
copy: bool | None | Literal["_force_false"] = True,
xp: ModuleType | None = None,
_is_update: bool = True,
**kwargs: Any,
) -> tuple[Any, None] | tuple[None, Array]:
"""Perform common prepocessing.
Returns
-------
If the operation can be resolved by at[], (return value, None)
Otherwise, (None, preprocessed x)
"""
if self.idx is _undef:
raise TypeError(
"Index has not been set.\n"
"Usage: either\n"
" at(x, idx).set(value)\n"
"or\n"
" at(x)[idx].set(value)\n"
"(same for all other methods)."
)

x = self.x

if copy is True:
writeable = None
elif copy is False:
writeable = is_writeable_array(x)
if not writeable:
raise ValueError("Cannot modify parameter in place")
elif copy is None:
writeable = is_writeable_array(x)
copy = _is_update and not writeable
elif copy == "_force_false": # type: ignore[redundant-expr]
# __getitem__ with fancy index on a numpy array
writeable = True
copy = False
else:
raise ValueError(f"Invalid value for copy: {copy!r}")

if copy:
try:
at_ = x.at
except AttributeError:
# Emulate at[] behaviour for non-JAX arrays
# with a copy followed by an update
if xp is None:
xp = array_namespace(x)
# Create writeable copy of read-only numpy array
x = xp.asarray(x, copy=True)
else:
# Use JAX's at[] or other library that with the same duck-type API
args = (y,) if y is not _undef else ()
return getattr(at_[self.idx], at_op)(*args, **kwargs), None

# This blindly expects that if x is writeable its copy is also writeable
if _is_update:
if writeable is None:
writeable = is_writeable_array(x)
if not writeable:
# sparse crashes here
raise ValueError(f"Array {x} has no `at` method and is read-only")

return None, x

def get(self, **kwargs: Any) -> Any:
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
that the output is either a copy or a view; it also allows passing
keyword arguments to the backend.
"""
# __getitem__ with a fancy index always returns a copy.
# Avoid an unnecessary double copy.
# If copy is forced to False, raise.
# FIXME This is an assumption based on numpy behaviour; it may not hold true
# for other backends. Namely, a backend could decide to conditionally return a
# view if the index can be coerced into a slice.
if _is_fancy_index(self.idx):
if kwargs.get("copy") is False:
raise TypeError(
"Indexing an array with a fancy index always results in a copy"
)
# Skip copy inside _common, even if array is not writeable
kwargs["copy"] = "_force_false"

res, x = self._common("get", _is_update=False, **kwargs)
if res is not None:
return res
assert x is not None
return x[self.idx]

def set(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] = y`` and return the update array"""
res, x = self._common("set", y, **kwargs)
if res is not None:
return res
assert x is not None
x[self.idx] = y
return x

def _iop(
self,
at_op: str,
elwise_op: Callable[[Array, Array], Array],
y: Array,
/,
**kwargs: Any,
) -> Array:
"""x[idx] += y or equivalent in-place operation on a subset of x
which is the same as saying
x[idx] = x[idx] + y
Note that this is not the same as
operator.iadd(x[idx], y)
Consider for example when x is a numpy array and idx is a fancy index, which
triggers a deep copy on __getitem__.
"""
res, x = self._common(at_op, y, **kwargs)
if res is not None:
return res
assert x is not None
x[self.idx] = elwise_op(x[self.idx], y)
return x

def add(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] += y`` and return the updated array"""
return self._iop("add", operator.add, y, **kwargs)

def subtract(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] -= y`` and return the updated array"""
return self._iop("subtract", operator.sub, y, **kwargs)

def multiply(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] *= y`` and return the updated array"""
return self._iop("multiply", operator.mul, y, **kwargs)

def divide(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] /= y`` and return the updated array"""
return self._iop("divide", operator.truediv, y, **kwargs)

def power(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] **= y`` and return the updated array"""
return self._iop("power", operator.pow, y, **kwargs)

def min(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("min", xp.minimum, y, **kwargs)

def max(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("max", xp.maximum, y, **kwargs)
5 changes: 5 additions & 0 deletions src/array_api_extra/_lib/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
array_namespace, # pyright: ignore[reportUnknownVariableType]
device, # pyright: ignore[reportUnknownVariableType]
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
is_writeable_array, # pyright: ignore[reportUnknownVariableType]
)
except ImportError:
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
array_namespace, # pyright: ignore[reportUnknownVariableType]
device,
is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
)

__all__ = [
"array_namespace",
"device",
"is_array_api_obj",
"is_writeable_array",
]
2 changes: 2 additions & 0 deletions src/array_api_extra/_lib/_compat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ def array_namespace(
use_compat: bool | None = None,
) -> ArrayModule: ...
def device(x: Array, /) -> Device: ...
def is_array_api_obj(x: object, /) -> bool: ...
def is_writeable_array(x: object, /) -> bool: ...
Loading

0 comments on commit 162264b

Please sign in to comment.