Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Test for read-only arrays #205

Merged
merged 2 commits into from
Dec 16, 2024
Merged

ENH: Test for read-only arrays #205

merged 2 commits into from
Dec 16, 2024

Conversation

crusaderky
Copy link
Contributor

@crusaderky crusaderky commented Nov 25, 2024

New public function is_writeable_array, which introduce transparent support for read-only backends such as JAX (but more may be added in the future).

[EDIT] this PR also implemented an at function, mocking the syntax of JAX's omonymous method.
This function will be proposed in array-api-extra instead.

@crusaderky crusaderky changed the title Abstractions for read-only arrays [WIP] Abstractions for read-only arrays Nov 25, 2024
@crusaderky crusaderky force-pushed the jax branch 2 times, most recently from 3c2f31d to 6884a34 Compare November 25, 2024 18:28
@asmeurer
Copy link
Member

This seems related to the suggestions at #146. Ping @mdhaber @lucascolley to check if these helpers will be useful in SciPy.

@lucascolley
Copy link
Contributor

x-ref data-apis/array-api#609, cc @jakevdp @rgommers

@mdhaber
Copy link

mdhaber commented Nov 25, 2024

Yes, it looks like it!

Copy link
Contributor

@lucascolley lucascolley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, looks useful! The diffs will be dauntingly large I suppose, but hopefully trivial enough!

Comment on lines 866 to 870
if is_jax_array(x):
return x.at
if is_numpy_array(x) and not x.flags.writeable:
x = x.copy()
return _InPlaceAt(x, idx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if I understand correctly, this implementation works for:

  • JAX
  • NumPy
  • Mutable arrays (arrays for which the in-place updates work)

It doesn't work for immutable arrays coming from other libraries, for the reasons discussed in data-apis/array-api#845.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works for all arrays explicitly listed in array-api-compat, including numpy subclasses.
It won't work for immutable arrays from other, unknown libraries - in fact, they won't be even recognized as immutable. Without a standardized way for is_immutable_array to detect such a use case (__array_writeable__ interface?) I would not know how to implement it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would attempting to mutate the array as a fallback work, or could that lead to undesired side-effects?

# Something along these lines, but more robust
a = x[0]
try:
    x[0] = 0
    x[0] = a
    mutable = True
except:
    mutable = False

Copy link
Contributor Author

@crusaderky crusaderky Nov 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that reimplementing at[] without the ability to update and without specific knowledge of the library at hand would be extraordinarily inefficient.

Consider at(x, 0).add(1). Implementing it on top of where would be extremely expensive; an alternative would be to break the array down and rebuild it with concatenate and stack, which would be very complicated and probably fragile.

The snippet of code you just wrote would perform very poorly in many cases. Consider:

  • a library like dask: the x[0] = a line adds extra labour into the graph, which is very likely nontrivial
  • any library where the memory transparently moves from device to host and vice versa: a = x[0] would likely cause either a page or the whole array to do just that.

A somewhat slightly more reasonable alternative would be to

  1. blindly try the update
  2. if the update fails, try if by any chance the library has a at[] method with exactly the same API as JAX

At the moment we're only doing (1).
My personal opinion would be to explicitly cater for these libraries if and when they crop up.
Is this whole discussion hypothetical, or do we know of a specific read-only library other than JAX?

Copy link
Contributor Author

@crusaderky crusaderky Nov 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realised that sparse does not support updates, so it is going to face the same issues in scipy as JAX. Has there been any discussion about it?

array_api_compat/common/_helpers.py Outdated Show resolved Hide resolved
array_api_compat/common/_helpers.py Outdated Show resolved Hide resolved
array_api_compat/common/_helpers.py Outdated Show resolved Hide resolved
array_api_compat/common/_helpers.py Outdated Show resolved Hide resolved
@crusaderky
Copy link
Contributor Author

crusaderky commented Nov 26, 2024

The diffs will be dauntingly large I suppose

Do you mean the diffs inside scipy? Yes they will. Not much that can be done about it I'm afraid.

I gave this some more thought and reworked the PR.
Now

  • where is a wrapper around xp.where, which adds copy=True/False/None (defaults to True).
  • at is a simple wrapper around jnp.ndarray.at, which adds the copy=True/False/None (defaults to True) parameter. When x is not a JAX array, at behaves the same as JAX by default (e.g. deep-copy everything).

In scipy and similar libraries, we should replace all instances of

x[idx] += y

with

x = at(x, idx).add(y, copy=None)

(read below for the masked use case).

Offline comment by @rgommers :

this targets the wrong repo. array-api-compat is a compatibilty layer for things that are actually in the standard, or about to land there. array-api-extra is for things on top.

There is discussion elsewhere that something should be pushed into the standard.
I would like to suggest exactly that for this PR. I think we should add to the API standard:

  • An at[] method, as already implemented in JAX plus the copy=True optional parameter, plus additional backend-specific kwargs (e.g. the standard would omit things like indices_are_sorted). The behaviour for out-of-bounds indices should likewise be left to undefined.
  • adding the copy=True parameter to where.

Masked JAX arrays

I've deliberately omitted special-casing for masked JAX arrays, unlike in @rgommers 's POC here https://github.com/scipy/scipy/compare/main...rgommers:scipy:array-types-inplace-ops?expand=1.

IMHO, the problem with the POC code e.g.

    if is_jax(xp):
        if hasattr(idx, 'dtype') and xp.isdtype(idx.dtype, 'bool'):
            x = xp.where(idx, x + y, x)
        else:
            x = x.at[idx].add(y)
    else:
        x[idx] += val

is that it produces incorrect results when y has shape other than () or (1, ).

This strongly feels to me less of a generic problem with read-only arrays (which array-api-compat and scipy should cater for) and more of a JAX-specific quirk (needing to know all shapes in advance).
So I think we should implement the above snippet as a rewrite rule inside jax.jit:

  • IF idx is a JAX array of bools
  • AND y has shape () or (1, )
  • THEN rewrite x.at[idx].add(y) as jnp.where(idx, x + y, x)

Until that happens, scipy code will work with JAX, but it will break when jitted.

CC @jakevdp

@lucascolley
Copy link
Contributor

lucascolley commented Nov 26, 2024

this targets the wrong repo. array-api-compat is a compatibilty layer for things that are actually in the standard, or about to land there. array-api-extra is for things on top.

I think I disagree. Everything in array-api-extra should work with any standard-compatible library, whereas this implementation will not work with immutable arrays from libraries not covered by array-api-compat (there doesn't seem to be any way it could).

Meanwhile, none of the helper functions in array-api-compat (https://data-apis.org/array-api-compat/helper-functions.html) are in the standard1 - that's where this PR is targeting.

I suppose the worry is that this is effectively arguing for the standardisation of at (as @crusaderky has made explicit) by demonstrating that it can be implemented for all libraries currently supported by array-api-compat. At least, it seems to be requiring that array libraries give us some way of implementing it. I don't think it would be the end of the world for this to stay in array-api-compat though, rather than being pushed into the array libraries themselves, as long as there isn't a library for which support is too complicated.

Footnotes

  1. this might seem unfair since most of them are methods that have been standardised. In which case, this API should meet the bar for standardisation. But since it's not going into the standard, I think we can make an exception for concerns regarding adoption.

@lucascolley
Copy link
Contributor

Everything in array-api-extra should work with any standard-compatible library

I think there is room to make function behaviour vary on https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.info.capabilities.html. But I don't think we should add this to array-api-extra without some sort of array-level flag for mutability in the standard (data-apis/array-api#845).

@asmeurer
Copy link
Member

I think it's a fuzzy zone whether this belongs in compat or in extra. The scope for array-api-compat is spelled out here https://data-apis.org/array-api-compat/#scope. On the one hand, the at helper is much like the device or size helpers already in compat. It converts what would normally be a method of an array object to a helper function, since methods cannot be wrapped. But you could also argue it's not quite the same as that since the "method" being wrapped here isn't really __setitem__ but rather .at, which isn't in the standard (yet). So one could also argue it's out of scope for array-api-compat. I think the biggest concern here is whether adding this here would be jumping the gun a bit for something that might later be added to the standard itself.

iwhere feels more like something that you could make this argument about, and I do wonder if it ought to be here.

is_writeable_array seems fine, although the main concern there as noted is that it's hard to make it generic against libraries that aren't covered by array-api-compat.

(by the way, we should add a reference to array-api-extra to that scope section)

@lucascolley
Copy link
Contributor

But you could also argue it's not quite the same as that since the "method" being wrapped here isn't really setitem but rather .at, which isn't in the standard (yet). So one could also argue it's out of scope for array-api-compat. I think the biggest concern here is whether adding this here would be jumping the gun a bit for something that might later be added to the standard itself.

IMO the bar is finding agreement on a sensible API in data-apis/array-api#609 (and here). But we don't need to have any confidence that existing array libraries will adopt it, since the implementations can live in array-api-compat if need be.

@crusaderky
Copy link
Contributor Author

iwhere feels more like something that you could make this argument about, and I do wonder if it ought to be here.

Note that I renamed it to where now. It's identical to the numpy one with the added parameter copy=True|False|None.
I'm not sure I will need it in scipy either - I suspect it would be redundant with the @jax.jit change I suggested above (#205 (comment)) - so I'm inclined to take it out of this PR and potentially reopen the discussion in a later PR if it turns out I do need it.

@crusaderky crusaderky force-pushed the jax branch 5 times, most recently from ca33a66 to c8f6613 Compare November 27, 2024 21:56
@crusaderky
Copy link
Contributor Author

I think this is ready for a second round of consultation.
As mentioned above, I've excised where/iwhere from this PR.

sparse

sparse is read-only and doesn't support .at or similar tools to my understanding, so none of it works at the moment. My opinion is that I should XFAIL everything now and then discuss implementing an .at method in the sparse library itself.

CC @hameerabbasi

JAX apply() and ufuncs

I can't seem to make JAX's apply() work:

>>> import jax.numpy as jnp
>>> import numpy as np
>>> a = jnp.array([1,2,3])
>>> a.at[:2].apply(np.negative)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[]

This works without JIT...

>>> a.at[:2].set(np.negative(a[:2]))
Array([-1, -2,  3], dtype=int32)

... but it crashes with @jax.jit with the same error as above.
I think the giveaway is that intermediate output of the ufunc applied to the JAX array is a plain numpy array:

>>> np.negative(a)
array([-1, -2, -3], dtype=int32)

Given

  • the above problems, and
  • the fact that there is a straightforward workaround with set() that works as long as plain ufunc application is fixed in JAX, and
  • that we need to implement the workaround anyway for cupy, torch, and dask, as none of them implement ufunc.at, and
  • that the key feature of apply, that is the accumulation behaviour on repeated indices, won't be consistent anyway across backends, and
  • its limitation to unary ufuncs (unlike ufunc.at, which supports a second parameter),

I'm inclined to completely remove apply() from the API.

PyTorch ufuncs

Applying a ufunc to a torch array raises a warning, which makes the tests for min, max, and apply fail.
Should I suppress the warning in the test suite?

```python
>>> import torch
>>> import numpy
>>> a = torch.asarray([1,2,3])
>>> numpy.negative(a)
DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
tensor([-1, -2, -3])

tests/test_at.py Outdated Show resolved Hide resolved
tests/test_at.py Outdated Show resolved Hide resolved
@asmeurer
Copy link
Member

Looking a little closer at this, my personal opinion is that it would make more sense at the moment to put at in the array-api-extra library rather than here in array-api-compat. There is a lot more API surface to it than I had initially realized, i.e., all the methods on the object. array-api-extra is supposed to be a little looser on potential future API breakages (although we haven't completely figured out how that will work; I proposed making it vendor-only, but there are other ways to do this too).

It also looks like this is trying to use NumPy functions uniformly, regardless of whether the underlying library is NumPy. That's going to work for specific libraries that implement the NumPy interop (__array_function__ and so on), but it goes against the spirit of the array API. In particular, it limits how much this can be extended to other libraries beyond the "known" ones supported here. And this would never be part of the API if it were at some point standardized.

is_writable_array could make sense as a helper here. It would be nice if it at least optionally attempted to be generic for unknown libraries, but also it makes sense to not try to do that if it's too hard or has too many pitfalls. This definitely seems like an API that should be added to the array API (probably in the inspection namespace). I would at least propose it over on the array-api repo.

@asmeurer
Copy link
Member

Just to be clear, someone else (probably @rgommers) should make the actual decision on whether or not this belongs here or elsewhere.

@crusaderky
Copy link
Contributor Author

It also looks like this is trying to use NumPy functions uniformly, regardless of whether the underlying library is NumPy.

I've replaced np.minimum with xp.minimum - is this what you were referring to?

@asmeurer
Copy link
Member

Yes, that was what I meant. And also the apply function which you already mentioned you want to remove.

@jakevdp
Copy link
Contributor

jakevdp commented Nov 28, 2024

I can't seem to make JAX's apply() work:

In JAX, you can apply JAX functions, not NumPy functions:

>>> import jax.numpy as jnp
>>> a = jnp.array([1,2,3])
>>> a.at[:2].apply(jnp.negative)
Array([-1, -2,  3], dtype=int32)

That said, I don't think apply is all that important a function for this helper, so I wouldn't be opposed to leaving it out.

Same with the other snippet: it will work under JIT if you use jnp.negative rather than np.negative (so correctly-used xp.negative should be fine).

@crusaderky
Copy link
Contributor Author

crusaderky commented Dec 2, 2024

Removed at. Will open a PR in array-api-extra.

This PR is ready to be reviewed and merged.

@crusaderky crusaderky changed the title [WIP] Abstractions for read-only arrays Test for read-only arrays Dec 2, 2024
@crusaderky crusaderky closed this Dec 4, 2024
@crusaderky crusaderky reopened this Dec 4, 2024
@crusaderky crusaderky marked this pull request as ready for review December 6, 2024 17:14
crusaderky added a commit to crusaderky/array-api-compat that referenced this pull request Dec 6, 2024
@crusaderky crusaderky changed the title Test for read-only arrays ENH: Test for read-only arrays Dec 12, 2024
@crusaderky
Copy link
Contributor Author

This is ready for final review and merge.

array_api_compat/__init__.py Outdated Show resolved Hide resolved
https://numpy.org/neps/nep-0047-array-api-standard.html.
This is a small wrapper around NumPy, CuPy, JAX, sparse and others that is
compatible with the Array API standard https://data-apis.org/array-api/latest/.
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(unrelated to this PR, but probably no need to link to that NEP now)

@crusaderky
Copy link
Contributor Author

@lucascolley this is ready to be merged

@lucascolley
Copy link
Contributor

Sounds good, but I don't have merge permissions on this repo.

@ev-br
Copy link
Contributor

ev-br commented Dec 16, 2024

Would be helpful to summarize the current scope. So ATM, this PR only adds a single public function, in the vein of data-apis/array-api#845, correct?
Then it'd be helpful to spell the plan w.r.t. the spec: if the spec gets something in capabilities, will it obviate the need for this function, or will this function wrap it or....?

Copy link
Member

@rgommers rgommers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, seems ready to merge.

Then it'd be helpful to spell the plan w.r.t. the spec: if the spec gets something in capabilities, will it obviate the need for this function, or will this function wrap it or....?

I wouldn't worry about that here. The chance doesn't seem too high that the standard gets this (soon at least), and if it does then wrapping it if needed is straightforward. I'd merge this as is.

@crusaderky
Copy link
Contributor Author

Ready to merge

@ev-br ev-br merged commit cdd1c8d into data-apis:main Dec 16, 2024
42 checks passed
@ev-br
Copy link
Contributor

ev-br commented Dec 16, 2024

Thanks @crusaderky, all.

@crusaderky crusaderky deleted the jax branch December 17, 2024 08:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants