Skip to content

Commit

Permalink
WIP at() method
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Dec 9, 2024
1 parent 00b2ab7 commit dd068d2
Show file tree
Hide file tree
Showing 12 changed files with 452 additions and 17 deletions.
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
:nosignatures:
:toctree: generated
at
atleast_nd
cov
create_diagonal
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ run.source = ["array_api_extra"]
report.exclude_also = [
'\.\.\.',
'if typing.TYPE_CHECKING:',
'if TYPE_CHECKING:',
]


Expand Down Expand Up @@ -235,6 +236,7 @@ ignore = [
"PLR09", # Too many <...>
"PLR2004", # Magic value used in comparison
"ISC001", # Conflicts with formatter
"PD008", # pandas-use-of-dot-at
]
isort.required-imports = ["from __future__ import annotations"]

Expand Down
12 changes: 11 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
from ._funcs import (
at,
atleast_nd,
cov,
create_diagonal,
expand_dims,
kron,
setdiff1d,
sinc,
)

__version__ = "0.3.3"

# pylint: disable=duplicate-code
__all__ = [
"__version__",
"at",
"atleast_nd",
"cov",
"create_diagonal",
Expand Down
282 changes: 279 additions & 3 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

import typing
import operator
import warnings
from typing import TYPE_CHECKING, Any, Callable

if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from ._lib._typing import Array, ModuleType

from ._lib import _utils
from ._lib._compat import array_namespace
from ._lib._compat import (
array_namespace,
is_array_api_obj,
is_writeable_array,
)

__all__ = [
"at",
"atleast_nd",
"cov",
"create_diagonal",
Expand Down Expand Up @@ -546,3 +552,273 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device)
)
return xp.sin(y) / y


_undef = object()


class at:
"""
Update operations for read-only arrays.
This implements ``jax.numpy.ndarray.at`` for all backends.
Parameters
----------
x : array
Input array.
idx : index, optional
You may use two alternate syntaxes::
at(x, idx).set(value) # or get(), add(), etc.
at(x)[idx].set(value)
copy : bool, optional
True (default)
Ensure that the inputs are not modified.
False
Ensure that the update operation writes back to the input.
Raise ValueError if a copy cannot be avoided.
None
The array parameter *may* be modified in place if it is possible and
beneficial for performance.
You should not reuse it after calling this function.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer
Additionally, if the backend supports an `at` method, any additional keyword
arguments are passed to it verbatim; e.g. this allows passing
``indices_are_sorted=True`` to JAX.
Returns
-------
Updated input array.
Examples
--------
Given either of these equivalent expressions::
x = at(x)[1].add(2, copy=None)
x = at(x, 1).add(2, copy=None)
If x is a JAX array, they are the same as::
x = x.at[1].add(2)
If x is a read-only numpy array, they are the same as::
x = x.copy()
x[1] += 2
Otherwise, they are the same as::
x[1] += 2
Warning
-------
When you use copy=None, you should always immediately overwrite
the parameter array::
x = at(x, 0).set(2, copy=None)
The anti-pattern below must be avoided, as it will result in different behaviour
on read-only versus writeable arrays::
x = xp.asarray([0, 0, 0])
y = at(x, 0).set(2, copy=None)
z = at(x, 1).set(3, copy=None)
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
Warning
-------
The behaviour of update methods when the index is an array of integers which
contains multiple occurrences of the same index is undefined;
e.g. ``at(x, [0, 0]).set(2)``
Note
----
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
See Also
--------
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
"""

x: Array
idx: Any
__slots__ = ("x", "idx")

def __init__(self, x: Array, idx: Any = _undef, /):
self.x = x
self.idx = idx

def __getitem__(self, idx: Any) -> Any:
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
which looks prettier than ``at(x, slice(start, stop, step))``
and feels more intuitive coming from the JAX documentation.
"""
if self.idx is not _undef:
msg = "Index has already been set"
raise ValueError(msg)
self.idx = idx
return self

def _common(
self,
at_op: str,
y: Array = _undef,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
_is_update: bool = True,
**kwargs: Any,
) -> tuple[Any, None] | tuple[None, Array]:
"""Perform common prepocessing.
Returns
-------
If the operation can be resolved by at[], (return value, None)
Otherwise, (None, preprocessed x)
"""
if self.idx is _undef:
msg = (
"Index has not been set.\n"
"Usage: either\n"
" at(x, idx).set(value)\n"
"or\n"
" at(x)[idx].set(value)\n"
"(same for all other methods)."
)
raise TypeError(msg)

x = self.x

if copy is True:
writeable = None
elif copy is False:
writeable = is_writeable_array(x)
if not writeable:
msg = "Cannot modify parameter in place"
raise ValueError(msg)
elif copy is None:
writeable = is_writeable_array(x)
copy = _is_update and not writeable
else:
msg = f"Invalid value for copy: {copy!r}"
raise ValueError(msg)

if copy:
try:
at_ = x.at
except AttributeError:
# Emulate at[] behaviour for non-JAX arrays
# with a copy followed by an update
if xp is None:
xp = array_namespace(x)
# Create writeable copy of read-only numpy array
x = xp.asarray(x, copy=True)
if writeable is False:
# A copy of a read-only numpy array is writeable
writeable = None
else:
# Use JAX's at[] or other library that with the same duck-type API
args = (y,) if y is not _undef else ()
return getattr(at_[self.idx], at_op)(*args, **kwargs), None

if _is_update:
if writeable is None:
writeable = is_writeable_array(x)
if not writeable:
# sparse crashes here
msg = f"Array {x} has no `at` method and is read-only"
raise ValueError(msg)

return None, x

def get(self, **kwargs: Any) -> Any:
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
that the output is either a copy or a view; it also allows passing
keyword arguments to the backend.
"""
if kwargs.get("copy") is False and (
is_array_api_obj(self.idx)
or isinstance(self.idx, tuple)
and any(is_array_api_obj(i) for i in self.idx)
):
# Fancy index. Note that the array API spec does not allow for
# list, tuple, or numpy arrays although many backends support them.
msg = "get() with an array index always returns in a copy"
raise ValueError(msg)

res, x = self._common("get", _is_update=False, **kwargs)
if res is not None:
return res
assert x is not None
return x[self.idx]

def set(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] = y`` and return the update array"""
res, x = self._common("set", y, **kwargs)
if res is not None:
return res
assert x is not None
x[self.idx] = y
return x

def _iop(
self,
at_op: str,
elwise_op: Callable[[Array, Array], Array],
y: Array,
/,
**kwargs: Any,
) -> Array:
"""x[idx] += y or equivalent in-place operation on a subset of x
which is the same as saying
x[idx] = x[idx] + y
Note that this is not the same as
operator.iadd(x[idx], y)
Consider for example when x is a numpy array and idx is a fancy index, which
triggers a deep copy on __getitem__.
"""
res, x = self._common(at_op, y, **kwargs)
if res is not None:
return res
assert x is not None
x[self.idx] = elwise_op(x[self.idx], y)
return x

def add(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] += y`` and return the updated array"""
return self._iop("add", operator.add, y, **kwargs)

def subtract(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] -= y`` and return the updated array"""
return self._iop("subtract", operator.sub, y, **kwargs)

def multiply(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] *= y`` and return the updated array"""
return self._iop("multiply", operator.mul, y, **kwargs)

def divide(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] /= y`` and return the updated array"""
return self._iop("divide", operator.truediv, y, **kwargs)

def power(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] **= y`` and return the updated array"""
return self._iop("power", operator.pow, y, **kwargs)

def min(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("min", xp.minimum, y, **kwargs)

def max(self, y: Array, /, **kwargs: Any) -> Array:
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("max", xp.maximum, y, **kwargs)
10 changes: 8 additions & 2 deletions src/array_api_extra/_lib/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
array_namespace, # pyright: ignore[reportUnknownVariableType]
device, # pyright: ignore[reportUnknownVariableType]
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
is_writeable_array, # pyright: ignore[reportUnknownVariableType]
)
except ImportError:
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
array_namespace, # pyright: ignore[reportUnknownVariableType]
device,
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
)

__all__ = [
__all__ = (
"array_namespace",
"device",
]
"is_array_api_obj",
"is_writeable_array",
)
2 changes: 2 additions & 0 deletions src/array_api_extra/_lib/_compat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ def array_namespace(
use_compat: bool | None = None,
) -> ArrayModule: ...
def device(x: Array, /) -> Device: ...
def is_array_api_obj(x: object, /) -> bool: ...
def is_writeable_array(x: object, /) -> bool: ...
7 changes: 3 additions & 4 deletions src/array_api_extra/_lib/_typing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

import typing
from types import ModuleType
from typing import Any
from typing import TYPE_CHECKING, Any

if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from typing_extensions import override

# To be changed to a Protocol later (see data-apis/array-api#589)
Expand All @@ -18,5 +17,5 @@ def no_op_decorator(f): # pyright: ignore[reportUnreachable]
override = no_op_decorator

__all__ = ["ModuleType", "override"]
if typing.TYPE_CHECKING:
if TYPE_CHECKING:
__all__ += ["Array", "Device"]
Loading

0 comments on commit dd068d2

Please sign in to comment.