ENH: implement at #53

Open
wants to merge 2 commits into
base: main

crusaderky
Contributor

@crusaderky crusaderky commented Dec 9, 2024

Implement a new at(x, idx) or at(x)[idx] function, mimicking the syntax of JAX's method of the same name.
This is a prerequisite for JAX support in libraries that support the Array API, e.g. scipy.
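
To make the two call forms concrete, here is a minimal sketch of an at-style helper. It is not the implementation in this PR: `ToyAt` is a hypothetical name, it only handles Python lists with integer or slice indices, and it always returns a modified copy, which is the behaviour JAX arrays require.

```python
from __future__ import annotations


class ToyAt:
    """Hypothetical stand-in for an at() helper; works on Python lists only."""

    def __init__(self, x: list[float], idx: int | slice | None = None) -> None:
        self._x = x
        self._idx = idx

    def __getitem__(self, idx: int | slice) -> ToyAt:
        # ToyAt(x)[idx] is equivalent to ToyAt(x, idx)
        return ToyAt(self._x, idx)

    def _positions(self, n: int) -> list[int]:
        # Expand the stored index into explicit positions
        if self._idx is None:
            raise ValueError("Index has not been set")
        if isinstance(self._idx, slice):
            return list(range(*self._idx.indices(n)))
        return [self._idx]

    def set(self, y: float) -> list[float]:
        # Out-of-place update: the input list is never modified
        out = list(self._x)
        for i in self._positions(len(out)):
            out[i] = y
        return out

    def add(self, y: float) -> list[float]:
        # Out-of-place increment of the selected positions
        out = list(self._x)
        for i in self._positions(len(out)):
            out[i] += y
        return out


x = [0.0, 0.0, 0.0]
print(ToyAt(x, 1).set(5.0))   # [0.0, 5.0, 0.0]
print(ToyAt(x)[1:].add(2.0))  # [0.0, 2.0, 2.0]
```

A real helper would additionally decide, per backend, whether it can update in place instead of copying.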

Moved from data-apis/array-api-compat#205

Blockers

@lucascolley lucascolley marked this pull request as ready for review December 9, 2024 14:30
@lucascolley lucascolley marked this pull request as draft December 9, 2024 14:31
@crusaderky
Contributor Author

Rebased on top of #58

@crusaderky
Contributor Author

crusaderky commented Dec 11, 2024

  • Reverted rebase onto Saner linters #58
  • Fixed all linter issues
  • Added unit tests for all backends
  • Removed **kwargs. Their main original appeal was that they allowed passing performance tweaks to JAX, namely indices_are_sorted=True and unique_indices=True. However, I have now realised that the array API specification does not support integer array indices at all, and both of these parameters are pointless without them.
  • I'm having big second thoughts on get() - read below.

@crusaderky
Contributor Author

crusaderky commented Dec 11, 2024

Any idea how to fix this? Looks like the ubuntu_latest VM has an obsolete driver (or more likely no driver)

ERROR cupy_backends.cuda.api.runtime.CUDARuntimeError: cudaErrorInsufficientDriver: CUDA driver version is insufficient for CUDA runtime version

Comment on lines 703 to 728
if copy is False:
    if is_array_api_obj(self.idx):
        # Boolean index. Note that the array API spec
        # https://data-apis.org/array-api/latest/API_specification/indexing.html
        # does not allow for list, tuple, and tuples of slices plus one or more
        # one-dimensional array indices, although many backends support them.
        # So this check will encounter a lot of false negatives in real life,
        # which can be caught by testing the user code vs. array-api-strict.
        msg = "get() with an array index always returns a copy"
        raise ValueError(msg)

    # Prevent scalar indices together with copy=False.
    # Even if some backends may return a scalar view of the original, we chose to be
    # strict here because some other backends, such as numpy, definitely don't.
    tup_idx = self.idx if isinstance(self.idx, tuple) else (self.idx,)
    if any(
        i is not None and i is not Ellipsis and not isinstance(i, slice)
        for i in tup_idx
    ):
        msg = "get() with a scalar index typically returns a copy"
        raise ValueError(msg)

    # Note: this is not the same list of backends as is_writeable_array()
    if is_dask_array(x) or is_jax_array(x) or is_pydata_sparse_array(x):
        msg = f"get() on {array_namespace(x)} arrays always returns a copy"
        raise ValueError(msg)
Contributor Author

@crusaderky crusaderky Dec 11, 2024

These checks are very brittle and I'm not comfortable going on with get() as it is.

I can see two ways forward:

  1. Write a new function in array-api-compat, getitem_returns_view(x: Array, idx: Index) -> bool that explores all the nooks and crannies of all the possible intersections of arrays and indices, then use that function here
  2. Remove get() altogether. I have a strong suspicion it may not be needed in scipy to begin with. get(copy=False) is not portable anyway, and get(copy=True) can be trivially rewritten as asarray(x[idx], copy=True)
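
As a rough illustration of option 1, the index-side part of such a check could look like the sketch below. `index_may_return_view` is a hypothetical name, not an existing array-api-compat function. It only mirrors the tuple scan from the snippet under review and deliberately ignores the backend-specific rules (Dask, JAX, sparse) that a complete getitem_returns_view would also need.

```python
def index_may_return_view(idx: object) -> bool:
    """Return True if idx is made only of None, Ellipsis, and slice objects.

    True means a view is possible on view-capable backends, not guaranteed.
    Integer, boolean-mask, and array indices all yield False.
    """
    # Normalise a single index into a 1-tuple so both forms share one loop
    tup_idx = idx if isinstance(idx, tuple) else (idx,)
    return all(
        i is None or i is Ellipsis or isinstance(i, slice)
        for i in tup_idx
    )


print(index_may_return_view((slice(0, 2), Ellipsis, None)))  # True
print(index_may_return_view((slice(None), 1)))               # False
```

Even this small piece shows why the approach is brittle: the answer for a given index still depends on which backend receives it.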

Contributor Author

@crusaderky crusaderky Dec 11, 2024

The only use I can think of out of get(copy=True) would be to avoid an unnecessary double copy when idx is a boolean mask, and that's assuming that the user either
a. doesn't know ahead of time that they're using a boolean mask (possible, but not terribly likely), or
b. doesn't trust that x[bool_mask_idx] will always return a deep copy no matter what, with all backends, e.g. that no backend will try to coerce the mask into a slice

Contributor Author

I've removed get(), at least for the time being.

@lucascolley
Collaborator

I don't think we can get GPU CI without paying someone for it, cc @rgommers.

@crusaderky
Contributor Author

> I don't think we can get GPU CI without paying someone for it, cc @rgommers.

Can cupy run on a CPU-only host?
I've disabled it in CI for the time being

@rgommers
Member

> I don't think we can get GPU CI without paying someone for it, cc @rgommers.

Yep, that's on my radar to push forward this month, on multiple projects. Please feel free to open a new issue and assign it to me. I think we can hook up a shared GPU runner between this project and array-api-compat.

> Can cupy run on a CPU-only host?

I don't think so.

Collaborator

@lucascolley lucascolley left a comment

Overall this looks great to me, thanks! Added some comments.

I'll take another look here before it goes in. In the meantime, hopefully the upstream work goes smoothly, and it would be fantastic if @jakevdp / @rgommers / anyone else could take a closer look at the implementation!

@crusaderky crusaderky marked this pull request as ready for review December 12, 2024 11:18
@crusaderky crusaderky changed the title WIP New at() method [DNM] New at() method Dec 12, 2024
@crusaderky
Contributor Author

I gave a final round of polish. This is ready for final review and approval.
Note the blocker PRs and DNMs before it can be merged though.

@crusaderky crusaderky mentioned this pull request Dec 12, 2024
@crusaderky
Contributor Author

> > I don't think we can get GPU CI without paying someone for it, cc @rgommers.
>
> Yep, that's on my radar to push forward this month, on multiple projects. Please feel free to open a new issue and assign it to me. I think we can hook up a shared GPU runner between this project and array-api-compat.
>
> > Can cupy run on a CPU-only host?
>
> I don't think so.

#60

@crusaderky crusaderky changed the title [DNM] New at() method [DNM] New at() function Dec 12, 2024
@lucascolley
Collaborator

I resolved the merge conflicts after adding numpydoc to pre-commit and made some minor tweaks to the docs.

@crusaderky crusaderky changed the title [DNM] New at() function New at() function Dec 17, 2024
@crusaderky
Contributor Author

@lucascolley if you don't mind building against array-api-compat git tip until their next release, this PR is ready for final review and merge.

Collaborator

@lucascolley lucascolley left a comment

> @lucascolley if you don't mind building against array-api-compat git tip until their next release, this PR is ready for final review and merge

That's fine with me, but it might be cleaner to get an array-api-compat release out first unless there are any big blockers.

Let's give @rgommers and @jakevdp time to take a look if they would like to.

Once (or before) this is merged, do you think you could make a PR to my branch at scikit-learn/scikit-learn#30340? I think we will have to transition scikit-learn from using array-api-compat as an optional dependency to vendoring it, but that sounds feasible based on scikit-learn/scikit-learn#30367 (comment).

@lucascolley lucascolley changed the title New at() function ENH: implement at Dec 17, 2024
@crusaderky
Contributor Author

> Once (or before) this is merged, do you think you could make a PR to my branch at scikit-learn/scikit-learn#30340? I think we will have to transition scikit-learn from using array-api-compat as an optional dependency to vendoring it, but that sounds feasible based on scikit-learn/scikit-learn#30367 (comment).

lucascolley/scikit-learn#2

    msg = "Index has already been set"
    raise ValueError(msg)
self._idx = idx
return self

Maybe it would make sense to return a shallow copy rather than mutating self. Otherwise, someone might write something like this and be surprised by the behavior:

getter = at(x)
y = getter[0].add(1)
z = getter[1].add(2)
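
The difference between the two designs can be sketched with a toy helper over Python lists. `MutatingAt` and `CopyingAt` are hypothetical names, and the guard condition in `__getitem__` is assumed from the visible error message rather than copied from the PR.

```python
import copy


class MutatingAt:
    """Toy helper whose __getitem__ stores the index on self."""

    def __init__(self, x):
        self._x = x
        self._idx = None

    def __getitem__(self, idx):
        # Assumed guard, inferred from the "already been set" message
        if self._idx is not None:
            raise ValueError("Index has already been set")
        self._idx = idx
        return self

    def add(self, y):
        out = list(self._x)  # out-of-place update
        out[self._idx] += y
        return out


class CopyingAt(MutatingAt):
    """Same helper, but __getitem__ returns a shallow copy."""

    def __getitem__(self, idx):
        if self._idx is not None:
            raise ValueError("Index has already been set")
        new = copy.copy(self)  # shares self._x, gets its own _idx
        new._idx = idx
        return new


getter = CopyingAt([0, 0])
y = getter[0].add(1)  # [1, 0]
z = getter[1].add(2)  # [0, 2]; getter itself still has no index
```

With `MutatingAt`, the second `getter[1]` would raise ValueError because the first indexing already stored an index on the shared object.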

if res is not None:
    return res
assert x is not None
x[self._idx] = elwise_op(x[self._idx], y)

As currently implemented, the _iop path has different semantics for repeated indices between JAX and NumPy:

>>> import numpy as np
>>> x = np.zeros(4)
>>> idx = np.array([1, 2, 2, 3, 3, 3])
>>> x[idx] = x[idx] + 1
>>> x
array([0., 1., 1., 1.])

>>> import jax.numpy as jnp
>>> x = jnp.zeros(4)
>>> idx = jnp.array([1, 2, 2, 3, 3, 3])
>>> x.at[idx].add(1)
Array([0., 1., 2., 3.], dtype=float32)

At the very least, the difference should be documented.
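
The two behaviours can be reproduced without either library. The sketch below uses plain Python lists, and both function names are hypothetical: `buffered_add` follows the gather, compute, scatter order of `x[idx] = x[idx] + y`, while `accumulating_add` applies one increment per occurrence of an index, as in the JAX output above.

```python
def buffered_add(x, idx, y):
    """Gather, add, then scatter: repeated indices are incremented once."""
    out = list(x)
    gathered = [out[i] + y for i in idx]  # all reads happen before any write
    for i, value in zip(idx, gathered):
        out[i] = value                    # later duplicates overwrite earlier ones
    return out


def accumulating_add(x, idx, y):
    """Apply one increment per occurrence of each index."""
    out = list(x)
    for i in idx:
        out[i] += y
    return out


x = [0.0, 0.0, 0.0, 0.0]
idx = [1, 2, 2, 3, 3, 3]
print(buffered_add(x, idx, 1.0))      # [0.0, 1.0, 1.0, 1.0]
print(accumulating_add(x, idx, 1.0))  # [0.0, 1.0, 2.0, 3.0]
```

On the NumPy side, `np.add.at(x, idx, 1)` performs the unbuffered, accumulating form in place, so it would be one candidate for matching the JAX result on backends that provide it.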
