diff --git a/scar/main/_setup.py b/scar/main/_setup.py index bb6a7f4..0f1f591 100644 --- a/scar/main/_setup.py +++ b/scar/main/_setup.py @@ -92,7 +92,6 @@ def setup_anndata( raw_adata._inplace_subset_obs(raw_adata.X.sum(axis=1) >= min_raw_counts) raw_adata.obs["total_counts"] = raw_adata.X.sum(axis=1) - raw_count = raw_adata.X.astype(int).A # initial estimation of ambient profile, will be update ambient_prof = raw_adata.X.sum(axis=0) / raw_adata.X.sum() @@ -105,9 +104,16 @@ def setup_anndata( # calculate joint probability (log) of being cell-free droplets for each droplet log_prob = [] - batches = np.array_split(raw_count, n_batch) + batch_idx = np.floor( + np.array(range(raw_adata.shape[0])) / raw_adata.shape[0] * n_batch + ) + + # batches = np.array_split(raw_count, n_batch) for b in range(n_batch): - count_batch = batches[b] + try: + count_batch = raw_adata[batch_idx == b].X.astype(int).A + except MemoryError: + raise MemoryError("use more batches by setting a higher n_batch") log_prob_batch = Multinomial( probs=torch.tensor(ambient_prof), validate_args=False ).log_prob(torch.Tensor(count_batch))