Skip to content

Commit

Permalink
Backport PR #1790 on branch 0.11.x ((performance): speedup for zarr
Browse files Browse the repository at this point in the history
…-based sparse indexing) (#1801)

Co-authored-by: Ilan Gold <[email protected]>
  • Loading branch information
meeseeksmachine and ilan-gold authored Dec 10, 2024
1 parent f412fd3 commit bd93c32
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1790.performance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Batch slice-based indexing in {class}`anndata.abc.CSRDataset` and {class}`anndata.abc.CSCDataset` for performance boost in `zarr` {user}`ilan-gold`
42 changes: 27 additions & 15 deletions src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from abc import ABC
from collections.abc import Iterable
from functools import cached_property
from itertools import accumulate, chain
from itertools import accumulate, chain, pairwise
from math import floor
from pathlib import Path
from typing import TYPE_CHECKING, NamedTuple
Expand Down Expand Up @@ -250,30 +250,42 @@ def slice_as_int(s: slice, l: int) -> int:
def get_compressed_vectors(
x: BackedSparseMatrix, row_idxs: Iterable[int]
) -> tuple[Sequence, Sequence, Sequence]:
slices = [slice(*(x.indptr[i : i + 2])) for i in row_idxs]
data = np.concatenate([x.data[s] for s in slices])
indices = np.concatenate([x.indices[s] for s in slices])
indptr = list(accumulate(chain((0,), (s.stop - s.start for s in slices))))
indptr_slices = [slice(*(x.indptr[i : i + 2])) for i in row_idxs]
# HDF5 cannot handle out-of-order integer indexing
if isinstance(x.data, ZarrArray):
as_np_indptr = np.concatenate(
[np.arange(s.start, s.stop) for s in indptr_slices]
)
data = x.data[as_np_indptr]
indices = x.indices[as_np_indptr]
else:
data = np.concatenate([x.data[s] for s in indptr_slices])
indices = np.concatenate([x.indices[s] for s in indptr_slices])
indptr = list(accumulate(chain((0,), (s.stop - s.start for s in indptr_slices))))
return data, indices, indptr


def get_compressed_vectors_for_slices(
x: BackedSparseMatrix, slices: Iterable[slice]
) -> tuple[Sequence, Sequence, Sequence]:
indptr_sels = [x.indptr[slice(s.start, s.stop + 1)] for s in slices]
data = np.concatenate([x.data[s[0] : s[-1]] for s in indptr_sels])
indices = np.concatenate([x.indices[s[0] : s[-1]] for s in indptr_sels])
indptr_indices = [x.indptr[slice(s.start, s.stop + 1)] for s in slices]
indptr_limits = [slice(i[0], i[-1]) for i in indptr_indices]
# HDF5 cannot handle out-of-order integer indexing
if isinstance(x.data, ZarrArray):
indptr_int = np.concatenate([np.arange(s.start, s.stop) for s in indptr_limits])
data = x.data[indptr_int]
indices = x.indices[indptr_int]
else:
data = np.concatenate([x.data[s] for s in indptr_limits])
indices = np.concatenate([x.indices[s] for s in indptr_limits])
# Need to track the size of the gaps in the slices to each indptr subselection
total = indptr_sels[0][0]
offsets = [total]
for i, sel in enumerate(indptr_sels[1:]):
total = (sel[0] - indptr_sels[i][-1]) + total
offsets.append(total)
start_indptr = indptr_sels[0] - offsets[0]
gaps = (s1.start - s0.stop for s0, s1 in pairwise(indptr_limits))
offsets = accumulate(chain([indptr_limits[0].start], gaps))
start_indptr = indptr_indices[0] - next(offsets)
if len(slices) < 2: # there is only one slice so no need to concatenate
return data, indices, start_indptr
end_indptr = np.concatenate(
[s[1:] - offsets[i + 1] for i, s in enumerate(indptr_sels[1:])]
[s[1:] - o for s, o in zip(indptr_indices[1:], offsets)]
)
indptr = np.concatenate([start_indptr, end_indptr])
return data, indices, indptr
Expand Down

0 comments on commit bd93c32

Please sign in to comment.