Skip to content

Commit 8e3fef1

Browse files
authored
Add anndata to gpu for dask (#312)
* update anndata to XPU for dask * update test
1 parent c6861aa commit 8e3fef1

File tree

4 files changed

+105
-0
lines changed

4 files changed

+105
-0
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ markers = [
103103
[tool.hatch.build]
104104
# exclude big files that don’t need to be installed
105105
exclude = [
106+
"src/rapids_singlecell/_testing.py",
106107
"tests",
107108
"docs",
108109
"notebooks"

src/rapids_singlecell/_compat.py

+23
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from __future__ import annotations
22

33
import cupy as cp
4+
import numpy as np
45
from cupyx.scipy.sparse import csr_matrix
56
from dask.array import Array as DaskArray # noqa: F401
7+
from scipy.sparse import csc_matrix as csc_matrix_cpu
8+
from scipy.sparse import csr_matrix as csr_matrix_cpu
69

710

811
def _meta_dense(dtype):
@@ -11,3 +14,23 @@ def _meta_dense(dtype):
1114

1215
def _meta_sparse(dtype):
1316
return csr_matrix(cp.array((1.0,), dtype=dtype))
17+
18+
19+
def _meta_dense(dtype):
20+
return cp.zeros([0], dtype=dtype)
21+
22+
23+
def _meta_sparse(dtype):
24+
return csr_matrix(cp.array((1.0,), dtype=dtype))
25+
26+
27+
def _meta_dense_cpu(dtype):
28+
return np.zeros([0], dtype=dtype)
29+
30+
31+
def _meta_sparse_csr_cpu(dtype):
32+
return csr_matrix_cpu(np.array((1.0,), dtype=dtype))
33+
34+
35+
def _meta_sparse_csc_cpu(dtype):
36+
return csc_matrix_cpu(np.array((1.0,), dtype=dtype))

src/rapids_singlecell/get/_anndata.py

+22
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,21 @@
77
import numpy as np
88
from cupyx.scipy.sparse import csc_matrix as csc_matrix_gpu
99
from cupyx.scipy.sparse import csr_matrix as csr_matrix_gpu
10+
from dask.array import Array as DaskArray
1011
from scanpy.get import _get_obs_rep, _set_obs_rep
1112
from scipy.sparse import csc_matrix as csc_matrix_cpu
1213
from scipy.sparse import csr_matrix as csr_matrix_cpu
1314
from scipy.sparse import isspmatrix_csc as isspmatrix_csc_cpu
1415
from scipy.sparse import isspmatrix_csr as isspmatrix_csr_cpu
1516

17+
from rapids_singlecell._compat import (
18+
_meta_dense,
19+
_meta_dense_cpu,
20+
_meta_sparse,
21+
_meta_sparse_csc_cpu,
22+
_meta_sparse_csr_cpu,
23+
)
24+
1625
if TYPE_CHECKING:
1726
from anndata import AnnData
1827

@@ -79,6 +88,11 @@ def X_to_GPU(X: CPU_ARRAY_TYPE, warning: str = "X") -> GPU_ARRAY_TYPE:
7988
"""
8089
if isinstance(X, GPU_ARRAY_TYPE):
8190
pass
91+
elif isinstance(X, DaskArray):
92+
if isinstance(X._meta, csc_matrix_cpu):
93+
pass
94+
meta = _meta_sparse if isinstance(X._meta, csr_matrix_cpu) else _meta_dense
95+
X = X.map_blocks(X_to_GPU, meta=meta(X.dtype))
8296
elif isspmatrix_csr_cpu(X):
8397
X = csr_matrix_gpu(X)
8498
elif isspmatrix_csc_cpu(X):
@@ -146,6 +160,14 @@ def X_to_CPU(X: GPU_ARRAY_TYPE) -> CPU_ARRAY_TYPE:
146160
X
147161
Matrix or array to transfer to the host memory
148162
"""
163+
if isinstance(X, DaskArray):
164+
if isinstance(X._meta, csr_matrix_gpu):
165+
meta = _meta_sparse_csr_cpu
166+
elif isinstance(X._meta, csc_matrix_gpu):
167+
meta = _meta_sparse_csc_cpu
168+
else:
169+
meta = _meta_dense_cpu
170+
X = X.map_blocks(X_to_GPU, meta=meta(X.dtype))
149171
if isinstance(X, GPU_ARRAY_TYPE):
150172
X = X.get()
151173
else:

tests/dask/test_get.py

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from __future__ import annotations
2+
3+
import cupy as cp
4+
import numpy as np
5+
import pytest
6+
from scanpy.datasets import pbmc3k_processed
7+
from scipy import sparse
8+
9+
import rapids_singlecell as rsc
10+
from rapids_singlecell._testing import (
11+
as_dense_cupy_dask_array,
12+
as_sparse_cupy_dask_array,
13+
)
14+
15+
16+
@pytest.mark.parametrize("data_kind", ["sparse", "dense"])
17+
def test_get_anndata(client, data_kind):
18+
adata = pbmc3k_processed()
19+
dask_adata = adata.copy()
20+
if data_kind == "sparse":
21+
adata.X = rsc.get.X_to_GPU(sparse.csr_matrix(adata.X.astype(np.float64)))
22+
dask_adata.X = as_sparse_cupy_dask_array(dask_adata.X.astype(np.float64))
23+
elif data_kind == "dense":
24+
adata.X = cp.array(adata.X.astype(np.float64))
25+
dask_adata.X = as_dense_cupy_dask_array(dask_adata.X.astype(np.float64))
26+
else:
27+
raise ValueError(f"Unknown data_kind {data_kind}")
28+
29+
assert type(adata.X) is type(dask_adata.X._meta)
30+
31+
if data_kind == "sparse":
32+
cp.testing.assert_array_equal(
33+
adata.X.toarray(), dask_adata.X.compute().toarray()
34+
)
35+
else:
36+
cp.testing.assert_array_equal(adata.X, dask_adata.X.compute())
37+
38+
rsc.get.anndata_to_CPU(dask_adata)
39+
rsc.get.anndata_to_CPU(adata)
40+
41+
assert type(adata.X) is type(dask_adata.X._meta)
42+
43+
if data_kind == "sparse":
44+
cp.testing.assert_array_equal(
45+
adata.X.toarray(), dask_adata.X.compute().toarray()
46+
)
47+
else:
48+
cp.testing.assert_array_equal(adata.X, dask_adata.X.compute())
49+
rsc.get.anndata_to_GPU(dask_adata)
50+
rsc.get.anndata_to_GPU(adata)
51+
52+
assert type(adata.X) is type(dask_adata.X._meta)
53+
54+
if data_kind == "sparse":
55+
cp.testing.assert_array_equal(
56+
adata.X.toarray(), dask_adata.X.compute().toarray()
57+
)
58+
else:
59+
cp.testing.assert_array_equal(adata.X, dask_adata.X.compute())

0 commit comments

Comments
 (0)