RFC: add "mutable arrays" to capabilities #845
(to anticipate one response: no, it's not possible to make JAX arrays support mutation: central to JAX are transformations like |
For (3), could you prototype what it would look like in the case of gh-609? For |
This is very surprising. It would be nice to have a list of such occurrences here, because this was not supposed to happen according to our design guideline https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html |
The main example is
For example, it could look something like this in the specific case of updating an array with a mask and a scalar:

```python
info = xp.__array_namespace_info__().capabilities()
if info['mutable arrays']:
    x[xp.isnan(x)] = 0
else:
    x = xp.where(xp.isnan(x), 0, x)
```

That would certainly not cover all cases, but it would be enough to fix a large number of the incompatibilities currently being introduced into scipy and scikit-learn. But in general, yes, it would also be beneficial if the array API standard could add some syntax for out-of-place array updates similar to what's being discussed in #609. |
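A minimal sketch of how a downstream library might package this branching, assuming the proposed "mutable arrays" key (it is not part of the current standard, so a missing key is treated as "not mutable"). `replace_nan` is a hypothetical helper and NumPy stands in for `xp`; the capabilities dict is passed in directly instead of being fetched, to keep the example independent of the NumPy version.

```python
import numpy as np


def replace_nan(x, xp, capabilities):
    """Hypothetical helper: set the NaN entries of `x` to 0.

    `capabilities` is assumed to be the dict returned by
    `xp.__array_namespace_info__().capabilities()`. The "mutable arrays"
    key is only the proposal from this RFC, so a missing key falls back
    to the out-of-place path.
    """
    mask = xp.isnan(x)
    if capabilities.get("mutable arrays", False):
        x[mask] = 0  # in-place: reuses the existing buffer
        return x
    return xp.where(mask, 0, x)  # out-of-place: also works for immutable arrays


# Callers always rebind, so the same call site serves both branches.
a = np.array([1.0, np.nan, 3.0])
a = replace_nan(a, np, {"mutable arrays": True})
print(a)  # [1. 0. 3.]
```

Always returning the result (and always rebinding at the call site) is what lets a single line of calling code work whether or not the array was mutated in place.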
Adding "mutable arrays" to library level |
Agreed with @pearu's comment. There are multiple other issues here though, for example:

(1) What does it mean to be a "mutable array"? To stay with the

```python
>>> import numpy as np
>>> x = np.arange(5)
>>> y = x[:3]
>>> y.flags.writeable = False
>>> y += 1
...
ValueError: output array is read-only
>>> y[0]
np.int64(0)
>>> x += 1
>>> y[0]
np.int64(1)
```

So is

(2) JAX, you'd argue, is immutable, I'm sure; however, as we saw in the example above, numpy read-only arrays reject in-place operators like

```python
>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> x[0]
Array(0, dtype=int32)
>>> x += 1
>>> x[0]
Array(1, dtype=int32)
```

So I'd say "is a mutable array" is quite ambiguous.
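Why `x += 1` can succeed on an immutable object can be illustrated with built-in types alone. In this analogy (an assumption for illustration, not a claim about JAX internals), `list` plays the role of a mutable array and `tuple` that of an immutable one: when a type defines no `__iadd__`, Python evaluates `x += y` as `x = x + y`, rebinding the name instead of mutating the object.

```python
# Mutable container: += calls list.__iadd__ and mutates the shared object.
m = [0, 1]
m_alias = m
m += [2]
print(m_alias)  # [0, 1, 2]

# Immutable container: tuple has no __iadd__, so += falls back to
# t = t + (2,), which rebinds t and leaves the alias untouched.
t = (0, 1)
t_alias = t
t += (2,)
print(t_alias, t)  # (0, 1) (0, 1, 2)

print(hasattr(tuple, "__iadd__"), hasattr(list, "__iadd__"))  # False True
```

So "the in-place operator is accepted" and "the object is mutable" are separate questions, which is part of why the term is ambiguous.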
I think that we should not add

```python
if is_jax(x):
    x[xp.isnan(x)] = 0
else:
    x = xp.where(xp.isnan(x), 0, x)
```

|
Yes,
In general, there are always ways to mutate the data of immutable objects. One can even mutate JAX arrays easily via the DLPack or array interface protocols. In this specific case, the example demonstrates the common practice of viewing data as read-only while the data could still be modified at some other level or time. For instance, one can open a file in read-only mode, and in this context the file descriptor would represent an immutable object, while some other process may open the same file in writable mode, which would enable mutations.
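NumPy itself exhibits this "read-only view of writable data" pattern. The sketch below uses only standard NumPy attributes (`flags.writeable`, `base`) and `np.shares_memory` to show that the writeable flag belongs to each array object, not to the underlying buffer.

```python
import numpy as np

x = np.arange(5)
y = x[:3]
y.flags.writeable = False

# The flag is per array object; both objects refer to the same buffer.
print(x.flags.writeable, y.flags.writeable)  # True False
print(y.base is x, np.shares_memory(x, y))   # True True

x += 1    # allowed, because x itself is writable
print(y)  # [1 2 3]: the read-only view observes the change
```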
This means that NumPy and JAX implement different semantics for in-place operations: for NumPy, an in-place operation is a mutating operation, while for JAX, the in-place operation is syntactic sugar for transformations:
I'd disagree. By definition, an "array" is a certain view of (contiguous) data whose elements can be accessed via an indexing operation.
You probably meant to write:

```python
if not is_jax(x):
    x[xp.isnan(x)] = 0
else:
    x = xp.where(xp.isnan(x), 0, x)
```

I find using "jax" in the name of a utility predicate function suboptimal, because JAX arrays are not the only array objects that are immutable. So, instead of introducing

```python
if is_mutable(x):
    x[xp.isnan(x)] = 0
else:
    x = xp.where(xp.isnan(x), 0, x)
```

so that the same code in scipy/... will not need to be modified when someone invents another Array API compliant array object that is immutable: it will be sufficient to update only the definition of |
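A minimal sketch of such a predicate, under loudly stated assumptions: `is_mutable` and `zero_nans` are hypothetical (neither exists in the standard or in any library mentioned here), and an array is treated as mutable only if it exposes a NumPy-style `flags.writeable` that is `True`. That rule is conservative: mutable arrays from libraries without such a flag would also take the copying path, which is exactly the per-library knowledge a single shared definition would have to encode.

```python
import numpy as np


def is_mutable(x):
    """Hypothetical predicate; NOT part of the array API standard.

    Assumption: only arrays exposing a NumPy-style `flags.writeable`
    attribute set to True count as mutable. Everything else is
    conservatively treated as immutable.
    """
    flags = getattr(x, "flags", None)
    return bool(getattr(flags, "writeable", False))


def zero_nans(x, xp):
    # Same branching as in the comment, with the result always returned.
    if is_mutable(x):
        x[xp.isnan(x)] = 0
        return x
    return xp.where(xp.isnan(x), 0, x)


ro = np.array([np.nan, 1.0])
ro.flags.writeable = False
out = zero_nans(ro, np)  # the read-only array takes the out-of-place path
print(out)  # [0. 1.]
```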
The problem is that we now need to fetch a capability of the array, rather than the array namespace, since NumPy has both behaviours. So what way could we address this if not for an |
Thinking about this a bit, I think the language about "views" in https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html is not quite strong enough. Let's assume "in-place update" is equivalent to

The kind of example I have in mind is this:

```python
x = xp.zeros(2)
L = [x]
x[0] = 1
print(L)
```

What is the result here? If

So the equivalence of in-place and out-of-place semantics doesn't just require the absence of array views in the sense of what's tracked by |
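With NumPy standing in for `xp`, the example can be run directly. `x` is not a view (`x.base is None`), yet the in-place update is visible through the list, while an out-of-place equivalent built with `np.where` is not. This is the gap between "no views" and "no other references".

```python
import numpy as np

# In-place update: the list holds a reference to the same array object.
x = np.zeros(2)
L1 = [x]
print(x.base is None)  # True: x owns its data, no view is involved
x[0] = 1
print(L1)  # [array([1., 0.])]

# Out-of-place update: x is rebound to a new array; the list is unaffected.
x = np.zeros(2)
L2 = [x]
x = np.where(np.arange(2) == 0, 1.0, x)
print(L2)  # [array([0., 0.])]
```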
Recall, Python x[0] = 1 uses JAX's semantics (

When modifying the @jakevdp example as follows:

```python
x = xp.zeros(2)
L = [x]
L[0][0] = 1
print(L)
```

the expected output would be |
Sorry, I think you misunderstood my point. I wasn't arguing that |
Ok, fair enough. A better example would be one that uses, say, some in-place operation ( |
Rather than getting lost in implementation details, let's bring it back to the statement I was responding to:
I think this is untrue, unless you also consider Python-level references as well as views when reasoning about whether an operation affects a "single array". And limiting operations to objects with a refcount of 1 is far more intrusive than limiting operations to arrays whose buffer is not shared with any other array objects. |
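As a CPython-specific illustration of why a refcount-based rule would be intrusive: storing an array in any ordinary container raises its reference count, even though no view is created and `x.base` stays `None`. Exact `sys.getrefcount` values are an implementation detail, so only the direction of the change matters here.

```python
import sys

import numpy as np

x = np.zeros(2)
before = sys.getrefcount(x)

holder = [x]  # an ordinary Python reference, not an array view
after = sys.getrefcount(x)

print(x.base is None)  # True: view tracking sees nothing
print(after > before)  # True: the list holds another reference
```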
Yes, that is a good point, and I agree that that page should be more explicit and grow a section on Python refcount >1. The behavior difference applies not only to
Agreed. I think (but am not sure; I have to give it some more thought) that this should remain undefined behavior: it's kinda baked into the Python language, and it's already a difference today between JAX and NumPy/PyTorch for |
Hey all – we chatted about this in today's meeting, and here is a summary:
|
Can I suggest re-opening this issue for visibility? scipy/scipy#22049 describes another use-case for an array-level mutability flag. We would like to make a copy after using |
Several parts of the Array API standard assume that array objects are mutable. Some array API implementations (notably JAX) do not support mutating array objects. As a result, the array API implementations currently being developed in scipy and sklearn are entirely unusable with JAX.

Given this, downstream implementations have a few choices:

(1) keep relying on array mutation, leaving JAX unsupported;
(2) avoid mutating arrays altogether, using out-of-place updates everywhere;
(3) special-case jax.numpy.Array, changing the implementation logic for that case.

(1) is a bad choice, because it means JAX will not be supported. (2) is a bad choice, because for libraries like NumPy, it leads to excessive copying of buffers, worsening performance. (3) is a bad choice because it hard-codes the presence of specific implementations in a context that is supposed to be implementation-agnostic.

One way the Array API standard could address this is by adding "mutable arrays" or something similar to the existing capabilities dict. Then downstream implementations could use strategy (3) without special-casing particular implementations.