Skip to content

Commit

Permalink
Self-review
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Dec 11, 2024
1 parent 18096a5 commit 2b18097
Show file tree
Hide file tree
Showing 8 changed files with 6,379 additions and 1,008 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
strategy:
fail-fast: false
matrix:
environment: [ci-py310, ci-py313]
environment: [ci-py310, ci-py313, ci-backends]
runs-on: [ubuntu-latest]

steps:
Expand Down
7,170 changes: 6,259 additions & 911 deletions pixi.lock

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,27 @@ python = "~=3.10.0"
[tool.pixi.feature.py313.dependencies]
python = "~=3.13.0"

[tool.pixi.feature.backends.target.linux-64.dependencies]
cupy = "*"
pytorch = "*"
dask = "*"
sparse = ">=0.15"
jax = "*"

[tool.pixi.feature.backends.target.osx-arm64.dependencies]
# cupy = "*"
pytorch = "*"
dask = "*"
sparse = ">=0.15"
jax = "*"

[tool.pixi.feature.backends.target.win-64.dependencies]
cupy = "*"
# pytorch = "*"
dask = "*"
sparse = ">=0.15"
# jax = "*"

[tool.pixi.environments]
default = { solve-group = "default" }
lint = { features = ["lint"], solve-group = "default" }
Expand All @@ -135,6 +156,7 @@ docs = { features = ["docs"], solve-group = "default" }
dev = { features = ["lint", "tests", "docs", "dev"], solve-group = "default" }
ci-py310 = ["py310", "tests"]
ci-py313 = ["py313", "tests"]
ci-backends = ["py310", "tests", "backends"]


# pytest
Expand Down Expand Up @@ -232,6 +254,7 @@ ignore = [
"PLR09", # Too many <...>
"PLR2004", # Magic value used in comparison
"ISC001", # Conflicts with formatter
"PD008", # Use `.loc` instead of `.at`
]

[tool.ruff.lint.per-file-ignores]
Expand Down
178 changes: 83 additions & 95 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
array_namespace,
is_array_api_obj,
is_dask_array,
is_jax_array,
is_pydata_sparse_array,
is_writeable_array,
)

if typing.TYPE_CHECKING:
from ._lib._typing import Array, Index, ModuleType, Untyped
from ._lib._typing import Array, Index, ModuleType

__all__ = [
"at",
Expand Down Expand Up @@ -593,11 +595,6 @@ class at: # pylint: disable=invalid-name
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer
**kwargs:
If the backend supports an `at` method, any additional keyword
arguments are passed to it verbatim; e.g. this allows passing
``indices_are_sorted=True`` to JAX.
Returns
-------
Updated input array.
Expand Down Expand Up @@ -674,23 +671,7 @@ def __getitem__(self, idx: Index, /) -> at:
self.idx = idx
return self

def _common(
self,
at_op: str,
y: Array = _undef,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
_is_update: bool = True,
**kwargs: Untyped,
) -> tuple[Array, None] | tuple[None, Array]:
"""Perform common prepocessing.
Returns
-------
If the operation can be resolved by at[], (return value, None)
Otherwise, (None, preprocessed x)
"""
def _check_args(self, /, copy: bool | None) -> None:
if self.idx is _undef:
msg = (

Check warning on line 676 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L676

Added line #L676 was not covered by tests
"Index has not been set.\n"
Expand All @@ -702,64 +683,23 @@ def _common(
)
raise TypeError(msg)

Check warning on line 684 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L684

Added line #L684 was not covered by tests

x = self.x

if copy not in (True, False, None):
msg = f"copy must be True, False, or None; got {copy!r}" # pyright: ignore[reportUnreachable]
raise ValueError(msg)

if copy is None:
writeable = is_writeable_array(x)
copy = _is_update and not writeable
elif copy:
writeable = None
elif _is_update:
writeable = is_writeable_array(x)
if not writeable:
msg = "Cannot modify parameter in place"
raise ValueError(msg)
else:
writeable = None

if copy:
try:
at_ = x.at
except AttributeError:
# Emulate at[] behaviour for non-JAX arrays
# with a copy followed by an update
if xp is None:
xp = array_namespace(x)
x = xp.asarray(x, copy=True)
if writeable is False:
# A copy of a read-only numpy array is writeable
# Note: this assumes that a copy of a writeable array is writeable
writeable = None
else:
# Use JAX's at[] or other library that with the same duck-type API
args = (y,) if y is not _undef else ()
return getattr(at_[self.idx], at_op)(*args, **kwargs), None

if _is_update:
if writeable is None:
writeable = is_writeable_array(x)
if not writeable:
# sparse crashes here
msg = f"Array {x} has no `at` method and is read-only"
raise ValueError(msg)

return None, x

def get(
self,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Untyped:
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
that the output is either a copy or a view; it also allows passing
) -> Array:
"""Return ``xp.asarray(x[idx])``. In addition to plain ``__getitem__``, this allows
ensuring that the output is either a copy or a view; it also allows passing
keyword arguments to the backend.
"""
self._check_args(copy=copy)
x = self.x

if copy is False:
if is_array_api_obj(self.idx):
# Boolean index. Note that the array API spec
Expand All @@ -782,26 +722,81 @@ def get(
msg = "get() with a scalar index typically returns a copy"
raise ValueError(msg)

if is_dask_array(self.x):
msg = "get() on Dask arrays always returns a copy"
# Note: this is not the same list of backends as is_writeable_array()
if is_dask_array(x) or is_jax_array(x) or is_pydata_sparse_array(x):
msg = f"get() on {array_namespace(x)} arrays always returns a copy"
raise ValueError(msg)

Check warning on line 728 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L727-L728

Added lines #L727 - L728 were not covered by tests

res, x = self._common("get", copy=copy, xp=xp, _is_update=False, **kwargs)
if res is not None:
return res
assert x is not None
return x[self.idx]
if is_jax_array(x):
# Use JAX's at[] or other library that with the same duck-type API
return x.at[self.idx].get()

Check warning on line 732 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L732

Added line #L732 was not covered by tests

if xp is None:
xp = array_namespace(x)
# Note: when self.idx is a boolean mask, numpy always returns a deep copy.
# However, some backends may legitimately return a view when the mask can
# be downgraded to a slice, e.g. a[[True, True, False]] -> a[:2].
# Err on the side of caution and perform a double-copy in numpy.
return xp.asarray(x[self.idx], copy=copy)

def _update_common(
self,
at_op: str,
y: Array = _undef,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
) -> tuple[Array, None] | tuple[None, Array]:
"""Perform common prepocessing to all update operations.
Returns
-------
If the operation can be resolved by at[], (return value, None)
Otherwise, (None, preprocessed x)
"""
x = self.x
if copy is None:
writeable = is_writeable_array(x)
copy = not writeable
elif copy:
writeable = None
else:
writeable = is_writeable_array(x)

if copy:
if is_jax_array(x):
# Use JAX's at[] or other library that with the same duck-type API
func = getattr(x.at[self.idx], at_op)
return func(y) if y is not _undef else func(), None

Check warning on line 770 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L769-L770

Added lines #L769 - L770 were not covered by tests
# Emulate at[] behaviour for non-JAX arrays
# with a copy followed by an update
if xp is None:
xp = array_namespace(x)
x = xp.asarray(x, copy=True)
if writeable is False:
# A copy of a read-only numpy array is writeable
# Note: this assumes that a copy of a writeable array is writeable
writeable = None

if writeable is None:
writeable = is_writeable_array(x)
if not writeable:
# sparse crashes here
msg = f"Array {x} has no `at` method and is read-only"
raise ValueError(msg)

return None, x

def set(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] = y`` and return the update array"""
res, x = self._common("set", y, copy=copy, xp=xp, **kwargs)
self._check_args(copy=copy)
res, x = self._update_common("set", y, copy=copy, xp=xp)
if res is not None:
return res

Check warning on line 801 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L801

Added line #L801 was not covered by tests
assert x is not None
Expand All @@ -818,7 +813,6 @@ def _iop(
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""x[idx] += y or equivalent in-place operation on a subset of x
Expand All @@ -829,7 +823,8 @@ def _iop(
Consider for example when x is a numpy array and idx is a fancy index, which
triggers a deep copy on __getitem__.
"""
res, x = self._common(at_op, y, copy=copy, xp=xp, **kwargs)
self._check_args(copy=copy)
res, x = self._update_common(at_op, y, copy=copy, xp=xp)
if res is not None:
return res

Check warning on line 829 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L829

Added line #L829 was not covered by tests
assert x is not None
Expand All @@ -842,79 +837,72 @@ def add(
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] += y`` and return the updated array"""
return self._iop("add", operator.add, y, copy=copy, xp=xp, **kwargs)
return self._iop("add", operator.add, y, copy=copy, xp=xp)

def subtract(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] -= y`` and return the updated array"""
return self._iop("subtract", operator.sub, y, copy=copy, xp=xp, **kwargs)
return self._iop("subtract", operator.sub, y, copy=copy, xp=xp)

def multiply(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] *= y`` and return the updated array"""
return self._iop("multiply", operator.mul, y, copy=copy, xp=xp, **kwargs)
return self._iop("multiply", operator.mul, y, copy=copy, xp=xp)

def divide(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] /= y`` and return the updated array"""
return self._iop("divide", operator.truediv, y, copy=copy, xp=xp, **kwargs)
return self._iop("divide", operator.truediv, y, copy=copy, xp=xp)

def power(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] **= y`` and return the updated array"""
return self._iop("power", operator.pow, y, copy=copy, xp=xp, **kwargs)
return self._iop("power", operator.pow, y, copy=copy, xp=xp)

def min(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
if xp is None:
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("min", xp.minimum, y, copy=copy, xp=xp, **kwargs)
return self._iop("min", xp.minimum, y, copy=copy, xp=xp)

def max(
self,
y: Array,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
**kwargs: Untyped,
) -> Array:
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
if xp is None:
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("max", xp.maximum, y, copy=copy, xp=xp, **kwargs)
return self._iop("max", xp.maximum, y, copy=copy, xp=xp)
6 changes: 6 additions & 0 deletions src/array_api_extra/_lib/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
device,
is_array_api_obj,
is_dask_array,
is_jax_array,
is_pydata_sparse_array,
is_writeable_array,
)
except ImportError:
Expand All @@ -16,6 +18,8 @@
device,
is_array_api_obj,
is_dask_array,
is_jax_array,
is_pydata_sparse_array,
is_writeable_array,
)

Expand All @@ -24,5 +28,7 @@
"device",
"is_array_api_obj",
"is_dask_array",
"is_jax_array",
"is_pydata_sparse_array",
"is_writeable_array",
)
Loading

0 comments on commit 2b18097

Please sign in to comment.