diff --git a/docs/source/notebooks/paste2_tutorial.ipynb b/docs/source/notebooks/paste2_tutorial.ipynb index 14b1ba9..a4a95a4 100644 --- a/docs/source/notebooks/paste2_tutorial.ipynb +++ b/docs/source/notebooks/paste2_tutorial.ipynb @@ -168,7 +168,7 @@ "metadata": {}, "outputs": [], "source": [ - "pi_AB = paste.pairwise_align(sliceA, sliceB, overlap_fraction=0.7, maxIter=20)" + "pi_AB, _ = paste.pairwise_align(sliceA, sliceB, overlap_fraction=0.7, maxIter=20)" ] }, { @@ -178,7 +178,7 @@ "metadata": {}, "outputs": [], "source": [ - "pi_BC = paste.pairwise_align(sliceB, sliceC, overlap_fraction=0.7, maxIter=20)" + "pi_BC, _ = paste.pairwise_align(sliceB, sliceC, overlap_fraction=0.7, maxIter=20)" ] }, { @@ -188,7 +188,7 @@ "metadata": {}, "outputs": [], "source": [ - "pi_CD = paste.pairwise_align(sliceC, sliceD, overlap_fraction=0.7, maxIter=20)" + "pi_CD, _ = paste.pairwise_align(sliceC, sliceD, overlap_fraction=0.7, maxIter=20)" ] }, { @@ -335,7 +335,7 @@ "pis = [pi_AB, pi_BC, pi_CD]\n", "slices = [sliceA, sliceB, sliceC, sliceD]\n", "\n", - "new_slices = visualization.stack_slices_pairwise(slices, pis, is_partial=True)" + "new_slices, _, _ = visualization.stack_slices_pairwise(slices, pis, is_partial=True)" ] }, { diff --git a/docs/source/notebooks/paste_tutorial.ipynb b/docs/source/notebooks/paste_tutorial.ipynb index 21c0311..ae50431 100644 --- a/docs/source/notebooks/paste_tutorial.ipynb +++ b/docs/source/notebooks/paste_tutorial.ipynb @@ -36,7 +36,7 @@ "import scanpy as sc\n", "import torch\n", "\n", - "from paste3.helper import filter_for_common_genes, match_spots_using_spatial_heuristic\n", + "from paste3.helper import get_common_genes, match_spots_using_spatial_heuristic\n", "from paste3.paste import center_align, pairwise_align\n", "from paste3.visualization import plot_slice, stack_slices_center, stack_slices_pairwise" ] @@ -186,9 +186,9 @@ "source": [ "start = time.time()\n", "\n", - "pi12 = pairwise_align(slice1, slice2)\n", - "pi23 = pairwise_align(slice2, slice3)\n", - "pi34 = pairwise_align(slice3, slice4)\n", + "pi12, _ = pairwise_align(slice1, slice2)\n", + "pi23, _ = pairwise_align(slice2, slice3)\n", + "pi34, _ = pairwise_align(slice3, slice4)\n", "\n", "print(\"Runtime: \" + str(time.time() - start))" ] @@ -218,7 +218,7 @@ "pis = [pi12, pi23, pi34]\n", "slices = [slice1, slice2, slice3, slice4]\n", "\n", - "new_slices = stack_slices_pairwise(slices, pis)" + "new_slices, _, _ = stack_slices_pairwise(slices, pis)" ] }, { @@ -368,7 +368,7 @@ "metadata": {}, "outputs": [], "source": [ - "filter_for_common_genes(slices)\n", + "slices, _ = get_common_genes(slices)\n", "\n", "b = []\n", "for i in range(len(slices)):\n", @@ -455,7 +455,7 @@ "metadata": {}, "outputs": [], "source": [ - "center, new_slices = stack_slices_center(center_slice, slices, pis)" + "center, new_slices, _, _ = stack_slices_center(center_slice, slices, pis)" ] }, { @@ -645,9 +645,9 @@ "source": [ "start = time.time()\n", "\n", - "pi12 = pairwise_align(slice1, slice2, use_gpu=True)\n", - "pi23 = pairwise_align(slice2, slice3, use_gpu=True)\n", - "pi34 = pairwise_align(slice3, slice4, use_gpu=True)\n", + "pi12, _ = pairwise_align(slice1, slice2, use_gpu=True)\n", + "pi23, _ = pairwise_align(slice2, slice3, use_gpu=True)\n", + "pi34, _ = pairwise_align(slice3, slice4, use_gpu=True)\n", "\n", "print(\"Runtime: \" + str(time.time() - start))" ] diff --git a/src/paste3/align.py b/src/paste3/align.py index 0ba2bb4..b98b6fa 100644 --- a/src/paste3/align.py +++ b/src/paste3/align.py @@ -32,7 +32,6 @@ def align( 
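A note on the unpacking pattern in the notebook cells above (the same change drives the `align.py` hunks that follow): `pairwise_align` now always returns a `(pi, log)` pair and `stack_slices_pairwise` always returns the transform parameters, so callers unpack tuples and discard what they do not need. A minimal sketch of the new calling convention, reusing the `sliceA`/`sliceB` objects loaded earlier in the tutorial; reading the final objective from `log["loss"][-1]` mirrors the updated tests and replaces the old `return_obj=True` path:

    from paste3.paste import pairwise_align
    from paste3.visualization import stack_slices_pairwise

    pi_AB, log = pairwise_align(sliceA, sliceB, overlap_fraction=0.7, maxIter=20)
    print(log["loss"][-1])  # final objective, formerly exposed via return_obj=True

    new_slices, angles, translations = stack_slices_pairwise(
        [sliceA, sliceB], [pi_AB], is_partial=True
    )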
     norm=False,
     numItermax=200,
     use_gpu=True,
-    return_obj=False,
     optimizeTheta=True,
     eps=1e-4,
     is_histology=False,
@@ -82,7 +81,7 @@ def align(
     logger.info("Computing Pairwise Alignment ")
     pis = []
     for i in range(n_slices - 1):
-        pi = pairwise_align(
+        pi, _ = pairwise_align(
             a_slice=slices[i],
             b_slice=slices[i + 1],
             overlap_fraction=overlap_fraction,
@@ -95,7 +94,6 @@ def align(
             norm=norm,
             numItermax=numItermax,
             use_gpu=use_gpu,
-            return_obj=return_obj,
             maxIter=max_iter,
             optimizeTheta=optimizeTheta,
             eps=eps,
@@ -245,6 +243,5 @@ def main(args):
         norm=args.norm,
         numItermax=args.max_iter,
         use_gpu=args.gpu,
-        return_obj=args.r_info,
         is_histology=args.hist,
     )
diff --git a/src/paste3/dataset.py b/src/paste3/dataset.py
index d688c1e..d47bd0c 100644
--- a/src/paste3/dataset.py
+++ b/src/paste3/dataset.py
@@ -160,15 +160,14 @@ def find_pis(self, overlap_fraction: float | list[float], max_iters: int = 1000)
         pis = []
         for i in range(len(self) - 1):
             logger.info(f"Finding Pi for slices {i} and {i+1}")
-            pis.append(
-                pairwise_align(
-                    self.slices[i].adata,
-                    self.slices[i + 1].adata,
-                    overlap_fraction=overlap_fraction[i],
-                    numItermax=max_iters,
-                    maxIter=max_iters,
-                )
+            pi, _ = pairwise_align(
+                self.slices[i].adata,
+                self.slices[i + 1].adata,
+                overlap_fraction=overlap_fraction[i],
+                numItermax=max_iters,
+                maxIter=max_iters,
             )
+            pis.append(pi)
         return pis

     def pairwise_align(
@@ -180,7 +179,7 @@ def pairwise_align(
         if pis is None:
             pis = self.find_pis(overlap_fraction=overlap_fraction, max_iters=max_iters)
         new_slices, rotation_angles, translations = stack_slices_pairwise(
-            self.slices_adata, pis, return_params=True
+            self.slices_adata, pis
         )
         aligned_dataset = AlignmentDataset(
             slices=[Slice(adata=s) for s in new_slices],
@@ -241,7 +240,6 @@ def center_align(
             center_slice=center_slice.adata,
             slices=self.slices_adata,
             pis=pis,
-            output_params=True,
         )
         aligned_dataset = AlignmentDataset(
             slices=[Slice(adata=s) for s in new_slices],
diff --git a/src/paste3/helper.py b/src/paste3/helper.py
index 5e80301..accefe0 100644
--- a/src/paste3/helper.py
+++ b/src/paste3/helper.py
@@ -20,9 +20,10 @@
 logger = logging.getLogger(__name__)


-def kl_divergence(a_exp_dissim, b_exp_dissim):
+def kl_divergence(a_exp_dissim, b_exp_dissim, is_generalized=False):
     """
-    Calculates the Kullback-Leibler divergence between two distributions.
+    Calculates the Kullback-Leibler divergence (KL) or generalized KL
+    divergence between two distributions.

     Parameters
     ----------
@@ -32,6 +33,9 @@ def kl_divergence(a_exp_dissim, b_exp_dissim):
     b_exp_dissim : torch.Tensor
         A tensor representing the second probability distribution.

+    is_generalized : bool
+        If True, computes the generalized KL divergence between the two distributions.
+
     Returns
     -------
     divergence : torch.Tensor
@@ -41,45 +45,24 @@ def kl_divergence(a_exp_dissim, b_exp_dissim):
         a_exp_dissim.shape[1] == b_exp_dissim.shape[1]
     ), "X and Y do not have the same number of features."
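The hunk continues below with the merged body: the plain path row-normalizes both inputs before taking logs, while the generalized path works on the raw values and adds the mass-difference terms. A small sketch of both modes, using the toy tensors from the test suite:

    import torch
    from paste3.helper import kl_divergence

    X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).double()
    Y = torch.tensor([[2.0, 4.0, 6.0], [8.0, 10.0, 12.0], [14.0, 16.0, 28.0]]).double()

    D_kl = kl_divergence(X, Y)                          # rows normalized, classic KL
    D_gkl = kl_divergence(X, Y, is_generalized=True)    # unnormalized, generalized KL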
-    a_exp_dissim = a_exp_dissim / a_exp_dissim.sum(axis=1, keepdims=True)
-    b_exp_dissim = b_exp_dissim / b_exp_dissim.sum(axis=1, keepdims=True)
-    a_log_exp_dissim = a_exp_dissim.log()
-    b_log_exp_dissim = b_exp_dissim.log()
-    a_weighted_dissim_sum = torch.sum(a_exp_dissim * a_log_exp_dissim, axis=1)[
-        torch.newaxis, :
-    ]
-    return a_weighted_dissim_sum.T - torch.matmul(a_exp_dissim, b_log_exp_dissim.T)
-
-
-def generalized_kl_divergence(a_exp_dissim, b_exp_dissim):
-    """
-    Computes the generalized Kullback-Leibler (KL) divergence between two distributions
-
-    Parameters
-    ----------
-    a_exp_dissim : torch.Tensor
-        A tensor representing the first probability distribution.
-
-    b_exp_dissim : torch.Tensor
-        A tensor representing the second probability distribution.
-
-    Returns
-    -------
-    divergence : torch.Tensor
-        A tensor containing the generalized Kullback-Leibler divergence for each sample.
-    """
-    assert (
-        a_exp_dissim.shape[1] == b_exp_dissim.shape[1]
-    ), "X and Y do not have the same number of features."
+    if not is_generalized:
+        a_exp_dissim = a_exp_dissim / a_exp_dissim.sum(axis=1, keepdims=True)
+        b_exp_dissim = b_exp_dissim / b_exp_dissim.sum(axis=1, keepdims=True)

     a_log_exp_dissim = a_exp_dissim.log()
     b_log_exp_dissim = b_exp_dissim.log()
+
     a_weighted_dissim_sum = torch.sum(a_exp_dissim * a_log_exp_dissim, axis=1)[
         torch.newaxis, :
     ]
+
     divergence = a_weighted_dissim_sum.T - torch.matmul(
         a_exp_dissim, b_log_exp_dissim.T
     )
+
+    if not is_generalized:
+        return divergence
+
     sum_a_exp_dissim = torch.sum(a_exp_dissim, axis=1)
     sum_b_exp_dissim = torch.sum(b_exp_dissim, axis=1)
     return (divergence.T - sum_a_exp_dissim).T + sum_b_exp_dissim.T
@@ -257,27 +240,26 @@ def to_dense_array(X):
     return torch.Tensor(np_array).double()


-def filter_for_common_genes(slices: list[AnnData]) -> None:
-    """
-    Filters a list of AnnData objects to retain only the common genes across
-    all slices.
+def get_common_genes(slices: list[AnnData]) -> tuple[list[AnnData], np.ndarray]:
+    """Returns the slices restricted to their shared genes, plus the common gene index."""
+    common_genes = slices[0].var.index

-    Parameters
-    ----------
-    slices: List[AnnData]
-        A list of AnnData objects that represent different slices.
-    """
-    assert len(slices) > 0, "Cannot have empty list."
+    for i, slice in enumerate(slices, start=1):
+        common_genes = common_genes.intersection(slice.var.index)
+        if len(common_genes) == 0:
+            logger.error(f"Slice {i} has no common genes with the rest of the slices.")
+            raise ValueError(f"Slice {i} has no common genes with the rest of the slices.")

-    common_genes = slices[0].var.index
-    for s in slices:
-        common_genes = common_genes.intersection(s.var.index)
-    for i in range(len(slices)):
-        slices[i] = slices[i][:, common_genes]
-    logging.info(
-        "Filtered all slices for common genes. There are "
-        + str(len(common_genes))
-        + " common genes."
+    return [slice[:, common_genes] for slice in slices], common_genes
+
+
+def compute_slice_weights(slice_weights, pis, slices, device):
+    return sum(
+        [
+            slice_weights[i]
+            * torch.matmul(pis[i], to_dense_array(slices[i].X).to(device))
+            for i in range(len(slices))
+        ]
     )
@@ -333,41 +315,6 @@ def match_spots_using_spatial_heuristic(
     return pi


-def kl_divergence_backend(a_exp_dissim, b_exp_dissim):
-    """
-    Calculates the Kullback-Leibler divergence between two distributions.
-
-    Parameters
-    ----------
-    a_exp_dissim : torch.Tensor
-        A tensor representing the first probability distribution.
-
-    b_exp_dissim : torch.Tensor
-        A tensor representing the second probability distribution.
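On the `get_common_genes` hunk above: unlike the in-place `filter_for_common_genes` it replaces, the new helper does not mutate its inputs, so callers must rebind the returned views. A sketch, assuming `slices` is a list of AnnData objects with at least one shared gene:

    from paste3.helper import get_common_genes

    slices, common_genes = get_common_genes(slices)  # rebind: inputs are untouched
    print(f"Retained {len(common_genes)} common genes")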
- - Returns - ------- - divergence : np.ndarray - A tensor containing the Kullback-Leibler divergence for each sample. - """ - assert ( - a_exp_dissim.shape[1] == b_exp_dissim.shape[1] - ), "X and Y do not have the same number of features." - - nx = ot.backend.get_backend(a_exp_dissim, b_exp_dissim) - - a_exp_dissim = a_exp_dissim / nx.sum(a_exp_dissim, axis=1, keepdims=True) - b_exp_dissim = b_exp_dissim / nx.sum(b_exp_dissim, axis=1, keepdims=True) - a_log_exp_dissim = nx.log(a_exp_dissim) - b_log_exp_dissim = nx.log(b_exp_dissim) - a_weighted_dissim_sum = nx.einsum("ij,ij->i", a_exp_dissim, a_log_exp_dissim) - a_weighted_dissim_sum = nx.reshape( - a_weighted_dissim_sum, (1, a_weighted_dissim_sum.shape[0]) - ) - divergence = a_weighted_dissim_sum.T - nx.dot(a_exp_dissim, b_log_exp_dissim.T) - return nx.to_numpy(divergence) - - def dissimilarity_metric(which, a_slice, b_slice, a_exp_dissim, b_exp_dissim, **kwargs): """ Computes a dissimilarity matrix between two distribution using a specified @@ -408,7 +355,9 @@ def dissimilarity_metric(which, a_slice, b_slice, a_exp_dissim, b_exp_dissim, ** case "gkl": a_exp_dissim = a_exp_dissim + 0.01 b_exp_dissim = b_exp_dissim + 0.01 - exp_dissim_matrix = generalized_kl_divergence(a_exp_dissim, b_exp_dissim) + exp_dissim_matrix = kl_divergence( + a_exp_dissim, b_exp_dissim, is_generalized=True + ) exp_dissim_matrix /= exp_dissim_matrix[exp_dissim_matrix > 0].max() exp_dissim_matrix *= 10 return exp_dissim_matrix diff --git a/src/paste3/model_selection.py b/src/paste3/model_selection.py index 6d05dca..0a8bdeb 100644 --- a/src/paste3/model_selection.py +++ b/src/paste3/model_selection.py @@ -216,7 +216,6 @@ def select_overlap_fraction(sliceA, sliceB, alpha=0.1, show_plot=True, numIterma exp_dissim_matrix=M, alpha=alpha, norm=True, - return_obj=True, numItermax=numItermax, maxIter=numItermax, ) diff --git a/src/paste3/paste.py b/src/paste3/paste.py index 3331542..5a1dd19 100644 --- a/src/paste3/paste.py +++ b/src/paste3/paste.py @@ -5,6 +5,7 @@ """ import logging +from collections.abc import Callable from typing import Any import numpy as np @@ -16,7 +17,9 @@ from torchnmf.nmf import NMF as TorchNMF from paste3.helper import ( + compute_slice_weights, dissimilarity_metric, + get_common_genes, to_dense_array, ) @@ -36,12 +39,11 @@ def pairwise_align( norm: bool = False, numItermax: int = 200, use_gpu: bool = True, - return_obj: bool = False, maxIter=1000, optimizeTheta=True, eps=1e-4, do_histology: bool = False, -) -> tuple[np.ndarray, int | None]: +) -> tuple[np.ndarray, dict | None]: r""" Returns a mapping :math:`( \Pi = [\pi_{ij}] )` between spots in one slice and spots in another slice while preserving gene expression and spatial distances of mapped spots, where :math:`\pi_{ij}` describes the probability that @@ -113,8 +115,6 @@ def pairwise_align( Maximum number of iterations for the optimization. use_gpu : bool, default=True Whether to use GPU for computations. If True but no GPU is available, will default to CPU. - return_obj : bool, default=False - If True, returns the optimization object along with the transport plan. maxIter : int, default=1000 Maximum number of iterations for the dissimilarity calculation. 
optimizeTheta : bool, default=True @@ -136,41 +136,19 @@ def pairwise_align( logger.info("GPU is not available, resorting to torch CPU.") use_gpu = False - # subset for common genes - common_genes = a_slice.var.index.intersection(b_slice.var.index) - a_slice = a_slice[:, common_genes] - b_slice = b_slice[:, common_genes] + device = "cuda" if use_gpu else "cpu" - # check if slices are valid - for slice in [a_slice, b_slice]: - if not len(slice): - raise ValueError(f"Found empty `AnnData`:\n{a_slice}.") + slices, _ = get_common_genes([a_slice, b_slice]) + a_slice, b_slice = slices - # Backend - nx = ot.backend.TorchBackend() + a_dist = torch.Tensor(a_slice.obsm["spatial"]).double() + b_dist = torch.Tensor(b_slice.obsm["spatial"]).double() - # Calculate spatial distances - a_coordinates = a_slice.obsm["spatial"].copy() - a_coordinates = nx.from_numpy(a_coordinates) - b_coordinates = b_slice.obsm["spatial"].copy() - b_coordinates = nx.from_numpy(b_coordinates) + a_exp_dissim = to_dense_array(a_slice.X).double().to(device) + b_exp_dissim = to_dense_array(b_slice.X).double().to(device) - a_spatial_dist = ot.dist(a_coordinates, a_coordinates, metric="euclidean") - b_spatial_dist = ot.dist(b_coordinates, b_coordinates, metric="euclidean") - - a_spatial_dist = a_spatial_dist.double() - b_spatial_dist = b_spatial_dist.double() - if use_gpu: - a_spatial_dist = a_spatial_dist.cuda() - b_spatial_dist = b_spatial_dist.cuda() - - # Calculate expression dissimilarity - a_exp_dissim = to_dense_array(a_slice.X) - b_exp_dissim = to_dense_array(b_slice.X) - - if use_gpu: - a_exp_dissim = a_exp_dissim.cuda() - b_exp_dissim = b_exp_dissim.cuda() + a_spatial_dist = torch.cdist(a_dist, a_dist).double().to(device) + b_spatial_dist = torch.cdist(b_dist, b_dist).double().to(device) if exp_dissim_matrix is None: exp_dissim_matrix = dissimilarity_metric( @@ -185,6 +163,7 @@ def pairwise_align( eps=eps, optimizeTheta=optimizeTheta, ) + exp_dissim_matrix = torch.Tensor(exp_dissim_matrix).double().to(device) if do_histology: # Calculate RGB dissimilarity @@ -194,7 +173,7 @@ def pairwise_align( torch.Tensor(b_slice.obsm["rgb"]).double(), ) .to(exp_dissim_matrix.dtype) - .to(exp_dissim_matrix.device) + .to(device) ) # Scale M_exp and rgb_dissim_matrix, obtain M by taking half from each @@ -202,38 +181,32 @@ def pairwise_align( rgb_dissim_matrix *= exp_dissim_matrix.max() exp_dissim_matrix = 0.5 * exp_dissim_matrix + 0.5 * rgb_dissim_matrix - # init distributions if a_spots_weight is None: - a_spots_weight = nx.ones((a_slice.shape[0],)) / a_slice.shape[0] + a_spots_weight = torch.ones((a_slice.shape[0],)) / a_slice.shape[0] + a_spots_weight = a_spots_weight.double().to(device) else: - a_spots_weight = nx.from_numpy(a_spots_weight) + a_spots_weight = torch.Tensor(a_spots_weight).double().to(device) if b_spots_weight is None: - b_spots_weight = nx.ones((b_slice.shape[0],)) / b_slice.shape[0] + b_spots_weight = torch.ones((b_slice.shape[0],)) / b_slice.shape[0] + b_spots_weight = b_spots_weight.double().to(device) else: - b_spots_weight = nx.from_numpy(b_spots_weight) - - exp_dissim_matrix = exp_dissim_matrix.double() - a_spots_weight = a_spots_weight.double() - b_spots_weight = b_spots_weight.double() - if use_gpu: - exp_dissim_matrix = exp_dissim_matrix.cuda() - a_spots_weight = a_spots_weight.cuda() - b_spots_weight = b_spots_weight.cuda() + b_spots_weight = torch.Tensor(b_spots_weight).double().to(device) if norm: - a_spatial_dist /= nx.min(a_spatial_dist[a_spatial_dist > 0]) - b_spatial_dist /= 
nx.min(b_spatial_dist[b_spatial_dist > 0]) + a_spatial_dist /= torch.min(a_spatial_dist[a_spatial_dist > 0]) + b_spatial_dist /= torch.min(b_spatial_dist[b_spatial_dist > 0]) if overlap_fraction: a_spatial_dist /= a_spatial_dist[a_spatial_dist > 0].max() a_spatial_dist *= exp_dissim_matrix.max() b_spatial_dist /= b_spatial_dist[b_spatial_dist > 0].max() b_spatial_dist *= exp_dissim_matrix.max() - # Run OT - if pi_init is not None and use_gpu: - pi_init.cuda() - pi, info = my_fused_gromov_wasserstein( + if pi_init is not None: + pi_init = torch.Tensor(pi_init).double().to(device) + pi_init = (1 / torch.sum(pi_init)) * pi_init + + return my_fused_gromov_wasserstein( exp_dissim_matrix, a_spatial_dist, b_spatial_dist, @@ -244,14 +217,7 @@ def pairwise_align( pi_init=pi_init, loss_fun="square_loss", numItermax=maxIter if overlap_fraction else numItermax, - use_gpu=use_gpu, ) - if not overlap_fraction: - info = info["fgw_dist"].item() - - if return_obj: - return pi, info - return pi def center_align( @@ -349,81 +315,40 @@ def center_align( logger.info("GPU is not available, resorting to torch CPU.") use_gpu = False + device = "cuda" if use_gpu else "cpu" + if slice_weights is None: slice_weights = len(slices) * [1 / len(slices)] if spots_weights is None: spots_weights = len(slices) * [None] - # get common genes - common_genes = initial_slice.var.index - for s in slices: - common_genes = common_genes.intersection(s.var.index) - - # subset common genes + slices, common_genes = get_common_genes(slices) initial_slice = initial_slice[:, common_genes] - for i in range(len(slices)): - slices[i] = slices[i][:, common_genes] - logger.info( - "Filtered all slices for common genes. There are " - + str(len(common_genes)) - + " common genes." - ) - # Run initial NMF - if exp_dissim_metric.lower() == "euclidean" or exp_dissim_metric.lower() == "euc": - nmf_model = NMF( - n_components=n_components, - init="random", - random_state=random_seed, - ) - else: - nmf_model = NMF( - n_components=n_components, - solver="mu", - beta_loss="kullback-leibler", - init="random", - random_state=random_seed, - ) - - if pi_inits is None: - pis = [None for i in range(len(slices))] - feature_matrix = nmf_model.fit_transform(initial_slice.X) - - else: - pis = pi_inits - feature_matrix = nmf_model.fit_transform( - initial_slice.shape[0] - * sum( - [ - slice_weights[i] * np.dot(pis[i], to_dense_array(slices[i].X)) - for i in range(len(slices)) - ] - ) - ) - coeff_matrix = nmf_model.components_ - center_coordinates = initial_slice.obsm["spatial"] + feature_matrix, coeff_matrix = center_NMF( + initial_slice.X, + slices, + pi_inits, + slice_weights, + n_components, + random_seed, + exp_dissim_metric=exp_dissim_metric, + device=device, + ) - if not isinstance(center_coordinates, np.ndarray): - logger.warning("A.obsm['spatial'] is not of type numpy array.") + pis = [None for _ in slices] if pi_inits is None else pi_inits - # Initialize center_slice - center_slice = AnnData(np.dot(feature_matrix, coeff_matrix)) - center_slice.var.index = common_genes - center_slice.obs.index = initial_slice.obs.index - center_slice.obsm["spatial"] = center_coordinates - - # Minimize loss iteration_count = 0 loss_init = 0 loss_diff = 100 while loss_diff > threshold and iteration_count < max_iter: - logger.info("Iteration: " + str(iteration_count)) + logger.info(f"Iteration: {iteration_count}") pis, loss = center_ot( feature_matrix, coeff_matrix, slices, - center_coordinates, + initial_slice.obsm["spatial"], common_genes, alpha, use_gpu, @@ -441,13 +366,13 
@@ def center_align( n_components, random_seed, exp_dissim_metric=exp_dissim_metric, + device=device, fast=fast, ) loss_new = np.dot(loss, slice_weights) iteration_count += 1 loss_diff = abs(loss_init - loss_new) - logger.info(f"Objective {loss_new}") - logger.info(f"Difference: {loss_diff}") + logger.info(f"Objective {loss_new} | Difference: {loss_diff}") loss_init = loss_new if pbar is not None: @@ -459,15 +384,7 @@ def center_align( center_slice.uns["paste_H"] = coeff_matrix center_slice.uns["full_rank"] = ( center_slice.shape[0] - * sum( - [ - slice_weights[i] - * torch.matmul(pis[i], to_dense_array(slices[i].X).to(pis[i].device)) - for i in range(len(slices)) - ] - ) - .cpu() - .numpy() + * compute_slice_weights(slice_weights, pis, slices, device).cpu().numpy() ) center_slice.uns["obj"] = loss_init return center_slice, pis @@ -549,26 +466,26 @@ def center_ot( slices[i], alpha=alpha, exp_dissim_metric=exp_dissim_metric, - norm=norm, - numItermax=numItermax, - return_obj=True, pi_init=pi_inits[i], b_spots_weight=spot_weights[i], + norm=norm, + numItermax=numItermax, use_gpu=use_gpu, ) pis.append(pi) - losses.append(loss) + losses.append(loss["loss"][-1].item()) return pis, np.array(losses) def center_NMF( feature_matrix: np.ndarray, slices: list[AnnData], - pis: list[torch.Tensor], + pis: list[torch.Tensor] | None, slice_weights: list[float] | None, n_components: int, random_seed: float, exp_dissim_metric: str = "kl", + device="cpu", fast: bool = False, ): r""" @@ -607,39 +524,34 @@ def center_NMF( The updated matrix of coefficients resulting from the NMF decomposition. """ logger.info("Solving Center Mapping NMF Problem.") - n_features = feature_matrix.shape[0] - weighted_features = n_features * sum( - [ - slice_weights[i] - * torch.matmul(pis[i], to_dense_array(slices[i].X).to(pis[i].device)) - for i in range(len(slices)) - ] - ) - if exp_dissim_metric.lower() == "euclidean" or exp_dissim_metric.lower() == "euc": - nmf_model = NMF( - n_components=n_components, - init="random", - random_state=random_seed, - ) - elif fast: - nmf_model = TorchNMF(weighted_features.T.shape, rank=n_components).to( - weighted_features.device + + if pis is not None: + pis = [torch.Tensor(pi).double().to(device) for pi in pis] + feature_matrix = ( + feature_matrix.shape[0] + * compute_slice_weights(slice_weights, pis, slices, device).cpu().numpy() ) + feature_matrix = torch.Tensor(feature_matrix).to(device) + if fast: + nmf_model = TorchNMF(feature_matrix.T.shape, rank=n_components) else: + exp_dissim_metric = exp_dissim_metric.lower() nmf_model = NMF( n_components=n_components, - solver="mu", - beta_loss="kullback-leibler", init="random", random_state=random_seed, + solver="cd" if exp_dissim_metric[:3] == "euc" else "mu", + beta_loss="frobenius" + if exp_dissim_metric[:3] == "euc" + else "kullback-leibler", ) if fast: - nmf_model.fit(weighted_features.T) + nmf_model.to(feature_matrix).fit(feature_matrix.T) new_feature_matrix = nmf_model.W.double().detach().cpu().numpy() new_coeff_matrix = nmf_model.H.T.detach().cpu().numpy() else: - new_feature_matrix = nmf_model.fit_transform(weighted_features.cpu().numpy()) + new_feature_matrix = nmf_model.fit_transform(feature_matrix.cpu()) new_coeff_matrix = nmf_model.components_ return new_feature_matrix, new_coeff_matrix @@ -658,11 +570,10 @@ def my_fused_gromov_wasserstein( numItermax: int | None = 200, tol_rel: float | None = 1e-9, tol_abs: float | None = 1e-9, - use_gpu: bool | None = True, numItermaxEmd: int | None = 100000, dummy: int | None = 1, **kwargs, -): +) 
-> tuple[np.ndarray, dict]: """ Computes a transport plan to align two weighted spatial distributions based on expression dissimilarity matrix and spatial distances, using the Gromov-Wasserstein framework. @@ -697,8 +608,6 @@ def my_fused_gromov_wasserstein( Relative tolerance for convergence, by default 1e-9. tol_abs : float, Optional Absolute tolerance for convergence, by default 1e-9. - use_gpu : bool, Optional - Whether to use GPU for computations. If True but no GPU is available, will default to CPU. numItermaxEmd : int, Optional Maximum iterations for Earth Mover's Distance (EMD) solver. dummy : int, Optional @@ -716,16 +625,6 @@ def my_fused_gromov_wasserstein( For more info, see: https://pythonot.github.io/gen_modules/ot.gromov.html """ - a_spots_weight, b_spots_weight = ot.utils.list_to_array( - a_spots_weight, b_spots_weight - ) - nx = ot.backend.get_backend( - a_spots_weight, - b_spots_weight, - a_spatial_dist, - b_spatial_dist, - exp_dissim_matrix, - ) if overlap_fraction is not None: if overlap_fraction < 0: @@ -756,36 +655,26 @@ def my_fused_gromov_wasserstein( ] ) - if pi_init is not None: - pi_init = (1 / nx.sum(pi_init)) * pi_init - if use_gpu: - pi_init = pi_init.cuda() - - def f_loss(pi): - """Compute the Gromov-Wasserstein loss for a given transport plan.""" - combined_spatial_cost, a_gradient, b_gradient = ot.gromov.init_matrix( + def transform_matrix(pi): + p, q = torch.sum(pi, axis=1), torch.sum(pi, axis=0) + return ot.gromov.init_matrix( a_spatial_dist, b_spatial_dist, - nx.sum(pi, axis=1).reshape(-1, 1).to(a_spatial_dist.dtype), - nx.sum(pi, axis=0).reshape(1, -1).to(b_spatial_dist.dtype), + p, + q, loss_fun, ) + + def f_loss(pi): + """Compute the Gromov-Wasserstein loss for a given transport plan.""" + combined_spatial_cost, a_gradient, b_gradient = transform_matrix(pi) return ot.gromov.gwloss(combined_spatial_cost, a_gradient, b_gradient, pi) def f_gradient(pi): """Compute the gradient of the Gromov-Wasserstein loss for a given transport plan.""" - combined_spatial_cost, a_gradient, b_gradient = ot.gromov.init_matrix( - a_spatial_dist, - b_spatial_dist, - nx.sum(pi, axis=1).reshape(-1, 1), - nx.sum(pi, axis=0).reshape(1, -1), - loss_fun, - ) + combined_spatial_cost, a_gradient, b_gradient = transform_matrix(pi) return ot.gromov.gwggrad(combined_spatial_cost, a_gradient, b_gradient, pi) - if loss_fun == "kl_loss": - armijo = True # there is no closed form line-search with KL - def line_search(f_cost, pi, pi_diff, linearized_matrix, cost_pi, _, **kwargs): """Solve the linesearch in the fused wasserstein iterations""" if overlap_fraction: @@ -795,9 +684,9 @@ def line_search(f_cost, pi, pi_diff, linearized_matrix, cost_pi, _, **kwargs): _info["err"].append(torch.norm(pi_diff)) count += 1 - if armijo: + if loss_fun == "kl_loss" or armijo: return ot.optim.line_search_armijo( - f_cost, pi, pi_diff, linearized_matrix, cost_pi, nx=nx, **kwargs + f_cost, pi, pi_diff, linearized_matrix, cost_pi, **kwargs ) if overlap_fraction: return line_search_partial( @@ -807,7 +696,8 @@ def line_search(f_cost, pi, pi_diff, linearized_matrix, cost_pi, _, **kwargs): a_spatial_dist, b_spatial_dist, pi_diff, - loss_fun=loss_fun, + f_cost, + f_gradient, ) return ot.gromov.solve_gromov_linesearch( G=pi, @@ -817,7 +707,6 @@ def line_search(f_cost, pi, pi_diff, linearized_matrix, cost_pi, _, **kwargs): C2=b_spatial_dist, M=0.0, reg=2 * 1.0, - nx=nx, **kwargs, ) @@ -849,7 +738,7 @@ def lp_solver( log=True, ) - return_val = ot.optim.generic_conditional_gradient( + pi, info = 
ot.optim.generic_conditional_gradient( a=a_spots_weight, b=b_spots_weight, M=(1 - alpha) * exp_dissim_matrix, @@ -866,15 +755,8 @@ def lp_solver( stopThr2=tol_abs, **kwargs, ) - - pi, info = return_val if overlap_fraction: - info["partial_fgw_cost"] = info["loss"][-1] info["err"] = _info["err"] - else: - info["fgw_dist"] = info["loss"][-1] - info["u"] = info["u"] - info["v"] = info["v"] return pi, info @@ -885,7 +767,8 @@ def line_search_partial( a_spatial_dist: torch.Tensor, b_spatial_dist: torch.Tensor, pi_diff: torch.Tensor, - loss_fun: str = "square_loss", + f_cost: Callable, + f_gradient: Callable, ): """ Solve the linesearch in the fused wasserstein iterations for partially overlapping slices @@ -917,31 +800,12 @@ def line_search_partial( cost_G : float The final cost after the update of the transport plan. """ - combined_spatial_cost, a_gradient, b_gradient = ot.gromov.init_matrix( - a_spatial_dist, - b_spatial_dist, - torch.sum(pi_diff, axis=1).reshape(-1, 1), - torch.sum(pi_diff, axis=0).reshape(1, -1), - loss_fun, - ) dot = torch.matmul(torch.matmul(a_spatial_dist, pi_diff), b_spatial_dist.T) a = alpha * torch.sum(dot * pi_diff) b = (1 - alpha) * torch.sum(exp_dissim_matrix * pi_diff) + 2 * alpha * torch.sum( - ot.gromov.gwggrad(combined_spatial_cost, a_gradient, b_gradient, pi_diff) - * 0.5 - * pi + f_gradient(pi_diff) * 0.5 * pi ) minimal_cost = ot.optim.solve_1d_linesearch_quad(a, b) - pi = pi + minimal_cost * pi_diff - combined_spatial_cost, a_gradient, b_gradient = ot.gromov.init_matrix( - a_spatial_dist, - b_spatial_dist, - torch.sum(pi, axis=1).reshape(-1, 1), - torch.sum(pi, axis=0).reshape(1, -1), - loss_fun, - ) - cost_G = (1 - alpha) * torch.sum(exp_dissim_matrix * pi) + alpha * ot.gromov.gwloss( - combined_spatial_cost, a_gradient, b_gradient, pi - ) + cost_G = f_cost(pi + minimal_cost * pi_diff) return minimal_cost, a, cost_G diff --git a/src/paste3/visualization.py b/src/paste3/visualization.py index b02308b..317be8a 100644 --- a/src/paste3/visualization.py +++ b/src/paste3/visualization.py @@ -15,10 +15,7 @@ def stack_slices_pairwise( - slices: list[AnnData], - pis: list[np.ndarray], - return_params: bool = False, - is_partial: bool = False, + slices: list[AnnData], pis: list[np.ndarray], is_partial: bool = False ) -> tuple[list[AnnData], list[float] | None, list[np.ndarray] | None]: """ Align spatial coordinates of sequential pairwise slices. 
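With `return_params` removed, `stack_slices_pairwise` has a single return shape: the aligned slices plus the per-pair rotation angles and translations, and callers that only want coordinates discard the extras, as the updated tests do. A sketch, assuming `slices` and `pis` are built as in the tutorials:

    from paste3.visualization import stack_slices_pairwise

    new_slices, rotation_angles, translations = stack_slices_pairwise(slices, pis)
    partial_slices, _, _ = stack_slices_pairwise(slices, pis, is_partial=True)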
@@ -50,50 +47,39 @@ def stack_slices_pairwise( aligned_coordinates = [] rotation_angles = [] translations = [] - result = generalized_procrustes_analysis( + ( + source_coordinates, + target_coordinates, + rotation_angle, + x_translation, + y_translation, + ) = generalized_procrustes_analysis( torch.Tensor(slices[0].obsm["spatial"]).to(pis[0].dtype).to(pis[0].device), torch.Tensor(slices[1].obsm["spatial"]).to(pis[0].dtype).to(pis[0].device), pis[0], is_partial=is_partial, - return_params=return_params, ) - if return_params: + rotation_angles.append(rotation_angle) + translations.append((x_translation, y_translation)) + aligned_coordinates.append(source_coordinates) + aligned_coordinates.append(target_coordinates) + for i in range(1, len(slices) - 1): ( source_coordinates, target_coordinates, rotation_angle, x_translation, y_translation, - ) = result - rotation_angles.append(rotation_angle) - translations.append((x_translation, y_translation)) - else: - source_coordinates, target_coordinates = result - aligned_coordinates.append(source_coordinates) - aligned_coordinates.append(target_coordinates) - for i in range(1, len(slices) - 1): - result = generalized_procrustes_analysis( + ) = generalized_procrustes_analysis( aligned_coordinates[i], torch.Tensor(slices[i + 1].obsm["spatial"]) .to(pis[i].dtype) .to(pis[i].device), pis[i], is_partial=is_partial, - return_params=return_params, ) - if return_params: - ( - source_coordinates, - target_coordinates, - rotation_angle, - x_translation, - y_translation, - ) = result - rotation_angles.append(rotation_angle) - translations.append((x_translation, y_translation)) - else: - source_coordinates, target_coordinates = result - + rotation_angles.append(rotation_angle) + translations.append((x_translation, y_translation)) if is_partial: shift = aligned_coordinates[i][0, :] - source_coordinates[0, :] target_coordinates = target_coordinates + shift @@ -105,17 +91,11 @@ def stack_slices_pairwise( _slice.obsm["spatial"] = aligned_coordinates[i].cpu().numpy() new_slices.append(_slice) - if not return_params: - return new_slices return new_slices, rotation_angles, translations def stack_slices_center( - center_slice: AnnData, - slices: list[AnnData], - pis: list[np.ndarray], - matrix: bool = False, - output_params: bool = False, + center_slice: AnnData, slices: list[AnnData], pis: list[np.ndarray] ) -> tuple[AnnData, list[AnnData], list[float] | None, list[np.ndarray] | None]: """ Align spatial coordinates of a list of slices to a center_slice. 
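`stack_slices_center` drops `matrix` and `output_params` the same way and now always returns the shifted center slice, the aligned slices, and the transform parameters. A sketch, assuming `center_slice`, `slices`, and `pis` come from a prior `center_align` run:

    from paste3.visualization import stack_slices_center

    new_center, new_slices, thetas, translations = stack_slices_center(
        center_slice, slices, pis
    )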
@@ -155,37 +135,21 @@ def stack_slices_center( translations = [] for i in range(len(slices)): - logger.info(f"Aligning slice {i} to center slice") - if not output_params: - source_coordinates, target_coordinates = generalized_procrustes_analysis( - torch.Tensor(center_slice.obsm["spatial"]) - .to(pis[i].dtype) - .to(pis[i].device), - torch.Tensor(slices[i].obsm["spatial"]) - .to(pis[i].dtype) - .to(pis[i].device), - pis[i], - ) - else: - ( - source_coordinates, - target_coordinates, - rotation_angle, - x_translation, - y_translation, - ) = generalized_procrustes_analysis( - torch.Tensor(center_slice.obsm["spatial"]) - .to(pis[i].dtype) - .to(pis[i].device), - torch.Tensor(slices[i].obsm["spatial"]) - .to(pis[i].dtype) - .to(pis[i].device), - pis[i], - return_params=output_params, - return_as_matrix=matrix, - ) - rotation_angles.append(rotation_angle) - translations.append((x_translation, y_translation)) + ( + source_coordinates, + target_coordinates, + rotation_angle, + x_translation, + y_translation, + ) = generalized_procrustes_analysis( + torch.Tensor(center_slice.obsm["spatial"]) + .to(pis[i].dtype) + .to(pis[i].device), + torch.Tensor(slices[i].obsm["spatial"]).to(pis[i].dtype).to(pis[i].device), + pis[i], + ) + rotation_angles.append(rotation_angle) + translations.append((x_translation, y_translation)) aligned_coordinates.append(target_coordinates) new_slices = [] @@ -196,8 +160,6 @@ def stack_slices_center( new_center = center_slice.copy() new_center.obsm["spatial"] = source_coordinates.cpu().numpy() - if not output_params: - return new_center, new_slices return new_center, new_slices, rotation_angles, translations @@ -228,12 +190,7 @@ def plot_slice( def generalized_procrustes_analysis( - source_coordinates, - target_coordinates, - pi, - return_params=False, - return_as_matrix=False, - is_partial=False, + source_coordinates, target_coordinates, pi, is_partial=False ): """ Finds and applies optimal rotation between spatial coordinates of two layers (may also do a reflection). 
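`generalized_procrustes_analysis` likewise loses its `return_params`/`return_as_matrix` switches and always returns five values: the translated source coordinates, the rotated target coordinates, the rotation angle, and the two translation components. A sketch with double-precision torch tensors, assuming `pi` is a transport plan between the two slices' spots and the function is imported from `paste3.visualization` as in the tests:

    import torch
    from paste3.visualization import generalized_procrustes_analysis

    src, tgt, angle, t_src, t_tgt = generalized_procrustes_analysis(
        torch.Tensor(sliceA.obsm["spatial"]).double(),
        torch.Tensor(sliceB.obsm["spatial"]).double(),
        pi,
    )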
@@ -268,24 +225,14 @@ def generalized_procrustes_analysis( U, S, Vt = torch.linalg.svd(covariance_matrix, full_matrices=True) rotation_matrix = Vt.T.matmul(U.T) target_coordinates = rotation_matrix.matmul(target_coordinates.T).T - if return_params and not return_as_matrix: - M = torch.Tensor([[0, -1], [1, 0]]).to(covariance_matrix) - rotation_angle = torch.arctan( - torch.trace(M.matmul(covariance_matrix)) / torch.trace(covariance_matrix) - ) - return ( - source_coordinates, - target_coordinates, - rotation_angle, - weighted_source, - weighted_target, - ) - if return_params and return_as_matrix: - return ( - source_coordinates, - target_coordinates, - rotation_matrix, - weighted_source, - weighted_target, - ) - return source_coordinates, target_coordinates + M = torch.Tensor([[0, -1], [1, 0]]).to(covariance_matrix) + rotation_angle = torch.arctan( + torch.trace(M.matmul(covariance_matrix)) / torch.trace(covariance_matrix) + ) + return ( + source_coordinates, + target_coordinates, + rotation_angle, + weighted_source, + weighted_target, + ) diff --git a/tests/test_model_selection.py b/tests/test_model_selection.py index 62783a6..7675a91 100644 --- a/tests/test_model_selection.py +++ b/tests/test_model_selection.py @@ -53,11 +53,11 @@ def test_edge_inconsistency_score(): def test_calculate_convex_hull_edge_inconsistency(slices): - pairwise_info = pairwise_align( + pairwise_info, _ = pairwise_align( slices[0], slices[1], - exp_dissim_metric="glmpca", overlap_fraction=0.7, + exp_dissim_metric="glmpca", norm=True, maxIter=10, ) diff --git a/tests/test_paste.py b/tests/test_paste.py index 1ae71e7..f775203 100644 --- a/tests/test_paste.py +++ b/tests/test_paste.py @@ -43,14 +43,14 @@ def assert_checksum_equals(temp_dir, filename, loose=False): def test_pairwise_alignment(slices): - outcome = pairwise_align( + outcome, _ = pairwise_align( slices[0], slices[1], alpha=0.1, exp_dissim_metric="kl", + pi_init=None, a_spots_weight=slices[0].obsm["weights"].astype(slices[0].X.dtype), b_spots_weight=slices[1].obsm["weights"].astype(slices[1].X.dtype), - pi_init=None, use_gpu=True, ) probability_mapping = pd.DataFrame( @@ -214,8 +214,6 @@ def test_fused_gromov_wasserstein(spot_distance_matrix): def test_gromov_linesearch(spot_distance_matrix): - nx = ot.backend.TorchBackend() - G = 1.509115054931788e-05 * torch.ones((251, 264)).double() deltaG = torch.Tensor( np.genfromtxt(input_dir / "deltaG.csv", delimiter=",") @@ -230,7 +228,6 @@ def test_gromov_linesearch(spot_distance_matrix): C2=spot_distance_matrix[2], M=0.0, reg=2 * 1.0, - nx=nx, ) assert alpha == 1.0 assert fc == 1 @@ -238,6 +235,7 @@ def test_gromov_linesearch(spot_distance_matrix): def test_line_search_partial(spot_distance_matrix): + d1, d2 = spot_distance_matrix[1], spot_distance_matrix[2] G = 1.509115054931788e-05 * torch.ones((251, 264)).double() deltaG = torch.Tensor( np.genfromtxt(input_dir / "deltaG.csv", delimiter=",") @@ -245,15 +243,30 @@ def test_line_search_partial(spot_distance_matrix): M = torch.Tensor( np.genfromtxt(input_dir / "euc_dissimilarity.csv", delimiter=",") ).double() + alpha = 0.1 - alpha, a, cost_G = line_search_partial( - alpha=0.1, + def f_cost(pi): + p, q = torch.sum(pi, axis=1), torch.sum(pi, axis=0) + constC, hC1, hC2 = ot.gromov.init_matrix(d1, d2, p, q) + return (1 - alpha) * torch.sum(M * pi) + alpha * ot.gromov.gwloss( + constC, hC1, hC2, pi + ) + + def f_gradient(pi): + p, q = torch.sum(pi, axis=1), torch.sum(pi, axis=0) + constC, hC1, hC2 = ot.gromov.init_matrix(d1, d2, p, q) + return 
ot.gromov.gwggrad(constC, hC1, hC2, pi) + + minimal_cost, a, cost_G = line_search_partial( + alpha=alpha, exp_dissim_matrix=M, pi=G, a_spatial_dist=spot_distance_matrix[1], b_spatial_dist=spot_distance_matrix[2], pi_diff=deltaG, + f_cost=f_cost, + f_gradient=f_gradient, ) - assert alpha == 1.0 + assert minimal_cost == 1.0 assert pytest.approx(a) == 0.4858849047237918 assert pytest.approx(cost_G) == 102.6333512778727 diff --git a/tests/test_paste2.py b/tests/test_paste2.py index f79623d..3c08fa8 100644 --- a/tests/test_paste2.py +++ b/tests/test_paste2.py @@ -22,7 +22,7 @@ def test_partial_pairwise_align_glmpca(fn, slices2): data = np.load(output_dir / "test_partial_pairwise_align.npz") fn.return_value = torch.Tensor(data["glmpca"]).double() - pi_BC = pairwise_align( + pi_BC, _ = pairwise_align( slices2[0], slices2[1], overlap_fraction=0.7, @@ -52,7 +52,6 @@ def test_partial_pairwise_align_given_cost_matrix(slices): exp_dissim_matrix=glmpca_distance_matrix, alpha=0.1, norm=True, - return_obj=True, numItermax=10, maxIter=10, ) @@ -62,7 +61,7 @@ def test_partial_pairwise_align_given_cost_matrix(slices): pd.read_csv(output_dir / "align_given_cost_matrix_pairwise_info.csv"), rtol=1e-04, ) - assert log["partial_fgw_cost"].cpu().numpy() == pytest.approx(40.86494022326222) + assert log["loss"][-1].cpu().numpy() == pytest.approx(40.86494022326222) def test_partial_pairwise_align_histology(slices2): @@ -71,14 +70,13 @@ def test_partial_pairwise_align_histology(slices2): slices2[1], overlap_fraction=0.7, alpha=0.1, - return_obj=True, exp_dissim_metric="euclidean", norm=True, numItermax=10, maxIter=10, do_histology=True, ) - assert log["partial_fgw_cost"].cpu().numpy() == pytest.approx(88.06713721008786) + assert log["loss"][-1].cpu().numpy() == pytest.approx(88.06713721008786) assert np.allclose( pairwise_info.cpu().numpy(), pd.read_csv(output_dir / "partial_pairwise_align_histology.csv").to_numpy(), @@ -172,4 +170,7 @@ def test_partial_fused_gromov_wasserstein(slices, armijo, expected_log, filename ) for k, v in expected_log.items(): - assert np.allclose(log[k], v, rtol=1e-05) + if k == "partial_fgw_cost": + assert np.allclose(log["loss"][-1], v, rtol=1e-05) + else: + assert np.allclose(log[k], v, rtol=1e-05) diff --git a/tests/test_paste_helpers.py b/tests/test_paste_helpers.py index 0d6e643..a51937c 100644 --- a/tests/test_paste_helpers.py +++ b/tests/test_paste_helpers.py @@ -8,12 +8,10 @@ from paste3.helper import ( dissimilarity_metric, - filter_for_common_genes, - generalized_kl_divergence, + get_common_genes, glmpca_distance, high_umi_gene_distance, kl_divergence, - kl_divergence_backend, match_spots_using_spatial_heuristic, norm_and_center_coordinates, pca_distance, @@ -36,10 +34,10 @@ def test_intersect(slices): def test_kl_divergence_backend(): - X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - Y = np.array([[2, 4, 6], [8, 10, 12], [14, 16, 28]]) + X = torch.Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])).double() + Y = torch.Tensor(np.array([[2, 4, 6], [8, 10, 12], [14, 16, 28]])).double() - kl_divergence_matrix = kl_divergence_backend(X, Y) + kl_divergence_matrix = kl_divergence(X, Y) expected_kl_divergence_matrix = np.array( [ [0.0, 0.03323784, 0.01889736], @@ -75,9 +73,7 @@ def test_kl_divergence(): def test_filter_for_common_genes(slices): - # creating a copy of the original list - slices = list(slices) - filter_for_common_genes(slices) + slices, _ = get_common_genes(slices) common_genes = list(np.genfromtxt(output_dir / "common_genes.csv", dtype=str)) for slice in slices: @@ 
-88,7 +84,7 @@ def test_generalized_kl_divergence(): X = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).double() Y = torch.Tensor([[2, 4, 6], [8, 10, 12], [14, 16, 28]]).double() - generalized_kl_divergence_matrix = generalized_kl_divergence(X, Y) + generalized_kl_divergence_matrix = kl_divergence(X, Y, is_generalized=True) expected_kl_divergence_matrix = np.array( [ [1.84111692, 14.54279955, 38.50128292], @@ -164,9 +160,7 @@ def test_high_umi_gene_distance(slices): [(True, "spots_mapping_true.csv"), (False, "spots_mapping_false.csv")], ) def test_match_spots_using_spatial_heuristic(slices, _use_ot, filename): # noqa: PT019 - # creating a copy of the original list - slices = list(slices) - filter_for_common_genes(slices) + slices, _ = get_common_genes(slices) spots_mapping = match_spots_using_spatial_heuristic( slices[0].X, slices[1].X, use_ot=bool(_use_ot) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 6894e61..8afde62 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -27,9 +27,7 @@ def test_stack_slices_pairwise(slices): for i in range(1, n_slices) ] - new_slices, thetas, translations = stack_slices_pairwise( - slices, pairwise_info, return_params=True - ) + new_slices, thetas, translations = stack_slices_pairwise(slices, pairwise_info) for i, slice in enumerate(new_slices, start=1): assert_frame_equal( @@ -60,7 +58,7 @@ def test_stack_slices_center(slices): ] new_center, new_slices, thetas, translations = stack_slices_center( - center_slice, slices, pairwise_info, output_params=True + center_slice, slices, pairwise_info ) assert_frame_equal( pd.DataFrame(new_center.obsm["spatial"], columns=["0", "1"]), @@ -104,7 +102,6 @@ def test_generalized_procrustes_analysis(slices): torch.Tensor(center_slice.obsm["spatial"]).double(), torch.Tensor(slices[0].obsm["spatial"]).double(), pairwise_info, - return_params=True, ) ) @@ -155,7 +152,7 @@ def test_partial_stack_slices_pairwise(slices): for i in range(1, n_slices) ] - new_slices = stack_slices_pairwise(slices, pairwise_info, is_partial=True) + new_slices, _, _ = stack_slices_pairwise(slices, pairwise_info, is_partial=True) for i, slice in enumerate(new_slices, start=1): assert_frame_equal( @@ -170,7 +167,7 @@ def test_partial_procrustes_analysis(slices2): assert torch.sum(torch.Tensor(data["pi"])) < 0.99999999 - x_aligned, y_aligned = generalized_procrustes_analysis( + x_aligned, y_aligned, _, _, _ = generalized_procrustes_analysis( torch.Tensor(slices2[0].obsm["spatial"]).double(), torch.Tensor(slices2[1].obsm["spatial"]).double(), torch.Tensor(data["pi"]).double(),