Skip to content

Commit

Permalink
Backport PR #1780 on branch 0.11.x ((fix): use dask array for missing…
Browse files Browse the repository at this point in the history
… element in dask concatenation) (#1800)

Co-authored-by: Ilan Gold <[email protected]>
  • Loading branch information
meeseeksmachine and ilan-gold authored Dec 10, 2024
1 parent ba06d2a commit f412fd3
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 8 deletions.
36 changes: 28 additions & 8 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,19 +939,37 @@ def gen_outer_reindexers(els, shapes, new_index: pd.Index, *, axis=0):
return reindexers


def missing_element(
n: int,
els: list[SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray],
axis: Literal[0, 1] = 0,
fill_value: Any | None = None,
) -> np.ndarray | DaskArray:
"""Generates value to use when there is a missing element."""
should_return_dask = any(isinstance(el, DaskArray) for el in els)
try:
non_missing_elem = next(el for el in els if not_missing(el))
except StopIteration: # pragma: no cover
msg = "All elements are missing when attempting to generate missing elements."
raise ValueError(msg)
# 0 sized array for in-memory prevents allocating unnecessary memory while preserving broadcasting.
off_axis_size = 0 if not should_return_dask else non_missing_elem.shape[axis - 1]
shape = (n, off_axis_size) if axis == 0 else (off_axis_size, n)
if should_return_dask:
import dask.array as da

return da.full(
shape, default_fill_value(els) if fill_value is None else fill_value
)
return np.zeros(shape, dtype=bool)


def outer_concat_aligned_mapping(
mappings, *, reindexers=None, index=None, axis=0, fill_value=None
):
result = {}
ns = [m.parent.shape[axis] for m in mappings]

def missing_element(n: int, axis: Literal[0, 1] = 0) -> np.ndarray:
"""Generates value to use when there is a missing element."""
if axis == 0:
return np.zeros((n, 0), dtype=bool)
else:
return np.zeros((0, n), dtype=bool)

for k in union_keys(mappings):
els = [m.get(k, MissingVal) for m in mappings]
if reindexers is None:
Expand All @@ -963,7 +981,9 @@ def missing_element(n: int, axis: Literal[0, 1] = 0) -> np.ndarray:
# We should probably just handle missing elements for all types
result[k] = concat_arrays(
[
el if not_missing(el) else missing_element(n, axis=axis)
el
if not_missing(el)
else missing_element(n, axis=axis, els=els, fill_value=fill_value)
for el, n in zip(els, ns)
],
cur_reindexers,
Expand Down
28 changes: 28 additions & 0 deletions tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,34 @@ def test_concat_different_types_dask(merge_strategy, array_type):
assert_equal(result2, target2)


def test_concat_missing_elem_dask_join(join_type):
import dask.array as da

import anndata as ad

ad1 = ad.AnnData(X=np.ones((5, 5)))
ad2 = ad.AnnData(X=np.zeros((5, 5)), layers={"a": da.ones((5, 5))})
ad_in_memory_with_layers = ad2.to_memory()

result1 = ad.concat([ad1, ad2], join=join_type)
result2 = ad.concat([ad1, ad_in_memory_with_layers], join=join_type)
assert_equal(result1, result2)


def test_impute_dask(axis_name):
import dask.array as da

from anndata._core.merge import _resolve_axis, missing_element

axis, _ = _resolve_axis(axis_name)
els = [da.ones((5, 5))]
missing = missing_element(6, els, axis=axis)
assert isinstance(missing, DaskArray)
in_memory = missing.compute()
assert np.all(np.isnan(in_memory))
assert in_memory.shape[axis] == 6


def test_outer_concat_with_missing_value_for_df():
# https://github.com/scverse/anndata/issues/901
# TODO: Extend this test to cover all cases of missing values
Expand Down

0 comments on commit f412fd3

Please sign in to comment.