-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multi-GPU support with dask #179
base: main
Are you sure you want to change the base?
Conversation
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From what I can tell, also CSC is just not really mentioned? It's not supported but maybe we should throw an error or something?
@@ -37,6 +37,7 @@ doc = [ | |||
"scanpydoc[typehints,theme]>=0.9.4", | |||
"readthedocs-sphinx-ext", | |||
"sphinx_copybutton", | |||
"dask", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These should be added to the dependencies.
adata_subset = adata[adata.obs[batch_key] == batch].copy() | ||
|
||
calculate_qc_metrics(adata_subset, layer=layer) | ||
filt = adata_subset.var["n_cells_by_counts"].to_numpy() > 0 | ||
adata_subset = adata_subset[:, filt] | ||
adata_subset = adata_subset[:, filt].copy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this is unrelated to this PR?
adata_subset = adata[adata.obs[batch_key] == batch].copy() | ||
|
||
calculate_qc_metrics(adata_subset, layer=layer) | ||
filt = adata_subset.var["n_cells_by_counts"].to_numpy() > 0 | ||
adata_subset = adata_subset[:, filt] | ||
adata_subset = adata_subset[:, filt].copy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a test for the "weird stuff?"
""" | ||
|
||
|
||
def _sparse_qc_csr_dask_cells(dtype): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I 100% agree with this. If it is a separate PR, that's fine by me. But we need to be able to maintain this library too if a fix needs to be made or something. This sort of thing goes a long way towards making things easier for the next person.
|
||
def _normalize_total(X: ArrayTypesDask, target_sum: int): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we just completely punting on CSC matrices?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For csc we transform to csr in normalize and thats always done.
from ._kernels._norm_kernel import _mul_csr | ||
|
||
mul_kernel = _mul_csr(X.dtype) | ||
mul_kernel.compile() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with Phil here. Calling compile should be universally applied on first-access as neeeded instead of manually having to remember to do it. We should be trying to make sure that if someone wants to add a new feature to RSC and forgets to call compile, that can't happen if you're out-of-commission.
def _get_target_sum_csr(X: sparse.csr_matrix) -> int: | ||
from ._kernels._norm_kernel import _get_sparse_sum_major | ||
|
||
counts_per_cell = cp.zeros(X.shape[0], dtype=X.dtype) | ||
sum_kernel = _get_sparse_sum_major(X.dtype) | ||
sum_kernel( | ||
(X.shape[0],), | ||
(64,), | ||
(X.indptr, X.data, counts_per_cell, X.shape[0]), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please refactor these lines into their own function that can then be applied to each block of the dask array and then do the masking + median at the end after either you do map_blocks
or _get_target_sum_csr
. We do this "recursive" map_blocks
in scanpy
a lot and it works very well and makes things very easy to reason about. These functions are basically identical. Especially if compiling really costs nothing extra as you say, then this should be doable. Maybe I'm missing something though. This seems like a good point in favor of "add compile to some always-called function" as Phil was saying
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see where you are coming from with this one. However I feel the compile might make this challenging. I have an idea on how to maybe fix this but this needs some testing and investigations. I could wrap the cuda kernel factory to call compile on the returned kernel. But I don't know if this than executed on the worker or host.
chunks=(X.chunksize[0],), | ||
drop_axis=1, | ||
) | ||
counts_per_cell = target_sum_chunk_matrices.compute() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why? Too much for this PR? Make an issue maybe?
svd_solver = "jacobi" | ||
pca_func = PCA(n_components=n_comps, svd_solver=svd_solver, whiten=False) | ||
X_pca = pca_func.fit_transform(X) | ||
X_pca = X_pca.compute_chunk_sizes() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to compute chunk sizes here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a comment for PCA X_pca = X_pca.compute_chunk_sizes()
.
For the copying in HVG. Since I use calculate_qc for to find where to subset. I wanted to be sure that the original data doesn't get overwritten. I addition to that the dask gpu views are sometimes a bit weird. So copy makes this more solid in general. That applies most to slicing against the minor axis.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also dont want dask to be a dependency because cuml handels that. I dont want get into problems there.
Co-authored-by: Ilan Gold <[email protected]>
I renamed the functions for QC and renamed some of the variables so its a bit clearer whats happening. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://github.com/scverse/rapids_singlecell/pull/179/files#r1838498091 is not done and from what I can tell #179 (review) has not been addressed. What happens if you pass a csc dask array to pca?
That will just error. And tell the user to please give me dense or csr as meta. I updated _check_gpu_X to reflect that. The median I'll test today |
We should look into the cost of allocating ahead of time for all operations that are currently in-place |
Median out of core is a bad choice. Uses way more memory and is slower. Loose Loose |
This adds dask support
Functions to add:
calculate_qc_metrics
normalize_total
log1p
highly_variable_genes
withseurat
andcell_ranger
scale
PCA
neighbors