Skip to content

Commit

Permalink
merged with main
Browse files Browse the repository at this point in the history
  • Loading branch information
anushka255 committed Nov 7, 2024
2 parents 34485d2 + feae479 commit 6a0c02f
Show file tree
Hide file tree
Showing 16 changed files with 663 additions and 171 deletions.
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ dependencies = [
"scikit-learn",
"IPython",
"statsmodels",
"torch"
"torch",
"torchnmf"
]
dynamic = ["version"]

Expand All @@ -41,6 +42,7 @@ dev = [
"coveralls",
"ruff",
"pre-commit",
"napari"
]

docs = [
Expand All @@ -56,6 +58,9 @@ docs = [
"squidpy"
]

[project.entry-points."napari.manifest"]
paste3 = "paste3.napari:napari.yaml"

[tool.setuptools]
package-dir = {"" = "src"}

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ stdlib-list==0.10.0
sympy==1.13.3
threadpoolctl==3.5.0
torch==2.4.1
torchnmf==0.3.5
tqdm==4.66.5
traitlets==5.14.3
typing_extensions==4.12.2
Expand Down
150 changes: 22 additions & 128 deletions scripts/workflow.py
Original file line number Diff line number Diff line change
@@ -1,140 +1,34 @@
import logging
from pathlib import Path

import numpy as np
import scanpy as sc
from anndata import AnnData

from paste3.helper import match_spots_using_spatial_heuristic
from paste3.paste import center_align, pairwise_align
from paste3.visualization import stack_slices_center, stack_slices_pairwise
from paste3.experimental import AlignmentDataset

logger = logging.getLogger(__name__)


class Slice:
def __init__(self, filepath: Path | None = None, adata: AnnData | None = None):
if adata is None:
self.adata = sc.read_h5ad(filepath)
else:
self.adata = adata

def __str__(self):
return f"Slice {self.adata}"


class AlignmentDataset:
@staticmethod
def from_csvs(gene_expression_csvs: list[Path], coordinate_csvs: list[Path]):
pass

def __init__(
self,
data_dir: Path | None = None,
slices: list[Slice] | None = None,
max_slices: int | None = None,
):
if slices is not None:
self.slices = slices[:max_slices]
else:
self.slices = [
Slice(filepath)
for filepath in sorted(Path(data_dir).glob("*.h5ad"))[:max_slices]
]

def __str__(self):
return f"Data with {len(self.slices)} slices"

def __iter__(self):
return iter(self.slices)

def __len__(self):
return len(self.slices)

@property
def slices_adata(self) -> list[AnnData]:
return [slice_.adata for slice_ in self.slices]

def align(
self,
center_align: bool = False,
center_slice: Slice | None = None,
pis: np.ndarray | None = None,
overlap_fraction: float | None = None,
max_iters: int = 1000,
):
if center_align:
if overlap_fraction is not None:
logger.warning(
"Ignoring overlap_fraction argument (unsupported in center_align mode)"
)
return self.center_align(center_slice, pis)
assert overlap_fraction is not None, "overlap_fraction must be specified"
return self.pairwise_align(
overlap_fraction=overlap_fraction, pis=pis, max_iters=max_iters
)

def find_pis(self, overlap_fraction: float, max_iters: int = 1000):
pis = []
for i in range(len(self) - 1):
logger.info(f"Finding Pi for slices {i} and {i+1}")
pis.append(
pairwise_align(
self.slices[i].adata,
self.slices[i + 1].adata,
overlap_fraction=overlap_fraction,
numItermax=max_iters,
maxIter=max_iters,
)
)
return pis

def pairwise_align(
self,
overlap_fraction: float,
pis: list[np.ndarray] | None = None,
max_iters: int = 1000,
):
if pis is None:
pis = self.find_pis(overlap_fraction=overlap_fraction, max_iters=max_iters)
new_slices = stack_slices_pairwise(self.slices_adata, pis)
return AlignmentDataset(slices=[Slice(adata=s) for s in new_slices])
if __name__ == "__main__":
dataset = AlignmentDataset(
"/home/vineetb/paste3/paste_reproducibility/data/SCC/cached-results/H5ADs/patient_2*"
)

def find_center_slice(
self, reference_slice: Slice | None = None, pis: np.ndarray | None = None
) -> tuple[Slice, list[np.ndarray]]:
if reference_slice is None:
reference_slice = self.slices[0]
center_slice, pis = center_align(
reference_slice.adata, self.slices_adata, pi_inits=pis
)
return Slice(adata=center_slice), pis
all_points_orig = dataset.all_points()

def find_pis_init(self) -> list[np.ndarray]:
reference_slice = self.slices[0]
return [
match_spots_using_spatial_heuristic(reference_slice.adata.X, slice_.adata.X)
for slice_ in self.slices
]
cluster_indices = set()
for slice in dataset.slices:
clusters = set(slice.get_obs_values("original_clusters"))
cluster_indices |= clusters
n_clusters = len(cluster_indices)

def center_align(
self,
reference_slice: Slice | None = None,
pis: list[np.ndarray] | None = None,
):
if reference_slice is None:
reference_slice, pis = self.find_center_slice(pis=pis)
else:
pis = self.find_pis_init()
# ------- Center Align ------- #
center_slice, pis = dataset.find_center_slice()
aligned_dataset = dataset.center_align(center_slice=center_slice, pis=pis)
all_points = aligned_dataset.all_points()

_, new_slices = stack_slices_center(
center_slice=reference_slice.adata, slices=self.slices_adata, pis=pis
)
return AlignmentDataset(slices=[Slice(adata=s) for s in new_slices])
center_slice.cluster(n_clusters, save_as="new_clusters")

new_clusters = center_slice.get_obs_values("new_clusters")
# ------- Center Align ------- #

if __name__ == "__main__":
dataset = AlignmentDataset("data/", max_slices=3)
aligned_dataset = dataset.align(
center_align=False, overlap_fraction=0.7, max_iters=2
)
# ------- Pairwise Align ------- #
aligned_dataset = dataset.pairwise_align(overlap_fraction=0.7)
all_points = aligned_dataset.all_points()
# ------- Pairwise Align ------- #
Loading

0 comments on commit 6a0c02f

Please sign in to comment.