Major Refactoring of the Code Base #84

Open · wants to merge 24 commits into base: main
8 changes: 4 additions & 4 deletions docs/source/notebooks/paste2_tutorial.ipynb
@@ -168,7 +168,7 @@
"metadata": {},
"outputs": [],
"source": [
"pi_AB = paste.pairwise_align(sliceA, sliceB, overlap_fraction=0.7, maxIter=20)"
"pi_AB, _ = paste.pairwise_align(sliceA, sliceB, overlap_fraction=0.7, maxIter=20)"
]
},
{
@@ -178,7 +178,7 @@
"metadata": {},
"outputs": [],
"source": [
"pi_BC = paste.pairwise_align(sliceB, sliceC, overlap_fraction=0.7, maxIter=20)"
"pi_BC, _ = paste.pairwise_align(sliceB, sliceC, overlap_fraction=0.7, maxIter=20)"
]
},
{
@@ -188,7 +188,7 @@
"metadata": {},
"outputs": [],
"source": [
"pi_CD = paste.pairwise_align(sliceC, sliceD, overlap_fraction=0.7, maxIter=20)"
"pi_CD, _ = paste.pairwise_align(sliceC, sliceD, overlap_fraction=0.7, maxIter=20)"
]
},
{
@@ -335,7 +335,7 @@
"pis = [pi_AB, pi_BC, pi_CD]\n",
"slices = [sliceA, sliceB, sliceC, sliceD]\n",
"\n",
"new_slices = visualization.stack_slices_pairwise(slices, pis, is_partial=True)"
"new_slices, _, _ = visualization.stack_slices_pairwise(slices, pis, is_partial=True)"
]
},
{
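Reviewer note: the tutorial edits above track the new tuple-returning API. A minimal sketch of the migration, assuming `sliceA`/`sliceB` are the AnnData slices loaded earlier in the notebook, and assuming the discarded second value is the optimization objective (inferred from this PR's removal of `return_obj`, not verified):

```python
from paste3 import paste, visualization  # module-style imports assumed, as in this notebook

# Old: pi_AB = paste.pairwise_align(sliceA, sliceB, overlap_fraction=0.7, maxIter=20)
# New: pairwise_align also returns a second value (presumably the objective), so unpack it.
pi_AB, _ = paste.pairwise_align(sliceA, sliceB, overlap_fraction=0.7, maxIter=20)

# stack_slices_pairwise now returns three values; per src/paste3/dataset.py below,
# the trailing two are rotation angles and translations.
new_slices, rotation_angles, translations = visualization.stack_slices_pairwise(
    [sliceA, sliceB], [pi_AB], is_partial=True
)
```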
20 changes: 10 additions & 10 deletions docs/source/notebooks/paste_tutorial.ipynb
@@ -36,7 +36,7 @@
"import scanpy as sc\n",
"import torch\n",
"\n",
"from paste3.helper import filter_for_common_genes, match_spots_using_spatial_heuristic\n",
"from paste3.helper import get_common_genes, match_spots_using_spatial_heuristic\n",
"from paste3.paste import center_align, pairwise_align\n",
"from paste3.visualization import plot_slice, stack_slices_center, stack_slices_pairwise"
]
@@ -186,9 +186,9 @@
"source": [
"start = time.time()\n",
"\n",
"pi12 = pairwise_align(slice1, slice2)\n",
"pi23 = pairwise_align(slice2, slice3)\n",
"pi34 = pairwise_align(slice3, slice4)\n",
"pi12, _ = pairwise_align(slice1, slice2)\n",
"pi23, _ = pairwise_align(slice2, slice3)\n",
"pi34, _ = pairwise_align(slice3, slice4)\n",
"\n",
"print(\"Runtime: \" + str(time.time() - start))"
]
@@ -218,7 +218,7 @@
"pis = [pi12, pi23, pi34]\n",
"slices = [slice1, slice2, slice3, slice4]\n",
"\n",
"new_slices = stack_slices_pairwise(slices, pis)"
"new_slices, _, _ = stack_slices_pairwise(slices, pis)"
]
},
{
@@ -368,7 +368,7 @@
"metadata": {},
"outputs": [],
"source": [
"filter_for_common_genes(slices)\n",
"slices, _ = get_common_genes(slices)\n",
"\n",
"b = []\n",
"for i in range(len(slices)):\n",
@@ -455,7 +455,7 @@
"metadata": {},
"outputs": [],
"source": [
"center, new_slices = stack_slices_center(center_slice, slices, pis)"
"center, new_slices, _, _ = stack_slices_center(center_slice, slices, pis)"
]
},
{
@@ -645,9 +645,9 @@
"source": [
"start = time.time()\n",
"\n",
"pi12 = pairwise_align(slice1, slice2, use_gpu=True)\n",
"pi23 = pairwise_align(slice2, slice3, use_gpu=True)\n",
"pi34 = pairwise_align(slice3, slice4, use_gpu=True)\n",
"pi12, _ = pairwise_align(slice1, slice2, use_gpu=True)\n",
"pi23, _ = pairwise_align(slice2, slice3, use_gpu=True)\n",
"pi34, _ = pairwise_align(slice3, slice4, use_gpu=True)\n",
"\n",
"print(\"Runtime: \" + str(time.time() - start))"
]
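Reviewer note: besides the tuple unpacking, this notebook picks up the helper rename. `filter_for_common_genes` filtered the list in place and returned nothing; `get_common_genes` returns the filtered slices together with the common gene index (signature grounded in the helper.py diff below). A before/after sketch, assuming `slices` is the list of AnnData objects from the notebook:

```python
from paste3.helper import get_common_genes

# Old (in-place): filter_for_common_genes(slices)
# New (functional): returns filtered slices plus the shared gene index.
slices, common_genes = get_common_genes(slices)
print(f"Slices share {len(common_genes)} genes")
```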
5 changes: 1 addition & 4 deletions src/paste3/align.py
@@ -32,7 +32,6 @@ def align(
norm=False,
numItermax=200,
use_gpu=True,
-return_obj=False,
optimizeTheta=True,
eps=1e-4,
is_histology=False,
@@ -82,7 +81,7 @@
logger.info("Computing Pairwise Alignment ")
pis = []
for i in range(n_slices - 1):
-pi = pairwise_align(
+pi, _ = pairwise_align(
a_slice=slices[i],
b_slice=slices[i + 1],
overlap_fraction=overlap_fraction,
@@ -95,7 +94,6 @@
norm=norm,
numItermax=numItermax,
use_gpu=use_gpu,
-return_obj=return_obj,
maxIter=max_iter,
optimizeTheta=optimizeTheta,
eps=eps,
@@ -245,6 +243,5 @@ def main(args):
norm=args.norm,
numItermax=args.max_iter,
use_gpu=args.gpu,
-return_obj=args.r_info,
is_histology=args.hist,
)
18 changes: 8 additions & 10 deletions src/paste3/dataset.py
@@ -160,15 +160,14 @@ def find_pis(self, overlap_fraction: float | list[float], max_iters: int = 1000)
pis = []
for i in range(len(self) - 1):
logger.info(f"Finding Pi for slices {i} and {i+1}")
-            pis.append(
-                pairwise_align(
-                    self.slices[i].adata,
-                    self.slices[i + 1].adata,
-                    overlap_fraction=overlap_fraction[i],
-                    numItermax=max_iters,
-                    maxIter=max_iters,
-                )
-            )
+            pi, _ = pairwise_align(
+                self.slices[i].adata,
+                self.slices[i + 1].adata,
+                overlap_fraction=overlap_fraction[i],
+                numItermax=max_iters,
+                maxIter=max_iters,
+            )
+            pis.append(pi)
return pis

def pairwise_align(
@@ -180,7 +179,7 @@ def pairwise_align(
if pis is None:
pis = self.find_pis(overlap_fraction=overlap_fraction, max_iters=max_iters)
new_slices, rotation_angles, translations = stack_slices_pairwise(
-self.slices_adata, pis, return_params=True
+self.slices_adata, pis
)
aligned_dataset = AlignmentDataset(
slices=[Slice(adata=s) for s in new_slices],
@@ -241,7 +240,6 @@ def center_align(
center_slice=center_slice.adata,
slices=self.slices_adata,
pis=pis,
-output_params=True,
)
aligned_dataset = AlignmentDataset(
slices=[Slice(adata=s) for s in new_slices],
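Reviewer note: with `return_params`/`output_params` gone, the stacking helpers always return the transform parameters. A sketch of the new signatures as used above (names borrowed from dataset.py; the meaning of the trailing values for `stack_slices_center` is assumed by analogy with the pairwise case):

```python
# Pairwise: aligned slices plus the recovered rigid transforms.
new_slices, rotation_angles, translations = stack_slices_pairwise(slices_adata, pis)

# Center: the aligned center slice first, then the aligned slices and transforms.
center, new_slices, rotation_angles, translations = stack_slices_center(
    center_slice.adata, slices_adata, pis
)
```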
123 changes: 36 additions & 87 deletions src/paste3/helper.py
@@ -20,9 +20,10 @@
logger = logging.getLogger(__name__)


-def kl_divergence(a_exp_dissim, b_exp_dissim):
+def kl_divergence(a_exp_dissim, b_exp_dissim, is_generalized=False):
     """
-    Calculates the Kullback-Leibler divergence between two distributions.
+    Calculates the Kullback-Leibler divergence (KL) or generalized KL
+    divergence between two distributions.

Parameters
----------
@@ -32,6 +33,9 @@ def kl_divergence(a_exp_dissim, b_exp_dissim):
b_exp_dissim : torch.Tensor
A tensor representing the second probability distribution.

+    is_generalized: bool
+        If True, computes the generalized KL divergence between the two distributions.

Returns
-------
divergence : torch.Tensor
@@ -41,45 +45,24 @@
a_exp_dissim.shape[1] == b_exp_dissim.shape[1]
), "X and Y do not have the same number of features."

-    a_exp_dissim = a_exp_dissim / a_exp_dissim.sum(axis=1, keepdims=True)
-    b_exp_dissim = b_exp_dissim / b_exp_dissim.sum(axis=1, keepdims=True)
-    a_log_exp_dissim = a_exp_dissim.log()
-    b_log_exp_dissim = b_exp_dissim.log()
-    a_weighted_dissim_sum = torch.sum(a_exp_dissim * a_log_exp_dissim, axis=1)[
-        torch.newaxis, :
-    ]
-    return a_weighted_dissim_sum.T - torch.matmul(a_exp_dissim, b_log_exp_dissim.T)
-
-
-def generalized_kl_divergence(a_exp_dissim, b_exp_dissim):
-    """
-    Computes the generalized Kullback-Leibler (KL) divergence between two distributions
-
-    Parameters
-    ----------
-    a_exp_dissim : torch.Tensor
-        A tensor representing first probability distribution.
-
-    b_exp_dissim : torch.Tensor
-        A tensor representing the second probability distribution.
-
-    Returns
-    -------
-    divergence : torch.Tensor
-        A tensor containing the generalized Kullback-Leibler divergence for each sample.
-    """
-    assert (
-        a_exp_dissim.shape[1] == b_exp_dissim.shape[1]
-    ), "X and Y do not have the same number of features."
+    if not is_generalized:
+        a_exp_dissim = a_exp_dissim / a_exp_dissim.sum(axis=1, keepdims=True)
+        b_exp_dissim = b_exp_dissim / b_exp_dissim.sum(axis=1, keepdims=True)
+
+    a_log_exp_dissim = a_exp_dissim.log()
+    b_log_exp_dissim = b_exp_dissim.log()
+
+    a_weighted_dissim_sum = torch.sum(a_exp_dissim * a_log_exp_dissim, axis=1)[
+        torch.newaxis, :
+    ]
+
+    divergence = a_weighted_dissim_sum.T - torch.matmul(
+        a_exp_dissim, b_log_exp_dissim.T
+    )
+
+    if not is_generalized:
+        return divergence

sum_a_exp_dissim = torch.sum(a_exp_dissim, axis=1)
sum_b_exp_dissim = torch.sum(b_exp_dissim, axis=1)
return (divergence.T - sum_a_exp_dissim).T + sum_b_exp_dissim.T
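Reviewer note: for checking the merged function, the two code paths compute the standard and generalized KL divergences between row $i$ of `a_exp_dissim` and row $j$ of `b_exp_dissim`. With `is_generalized=False`, rows are normalized to sum to one and the entry is

$$D_{ij} = \sum_k a_{ik} \log \frac{a_{ik}}{b_{jk}},$$

while the generalized path works on the unnormalized inputs and adds the mass terms,

$$D^{\text{gen}}_{ij} = \sum_k a_{ik} \log \frac{a_{ik}}{b_{jk}} - \sum_k a_{ik} + \sum_k b_{jk},$$

which is exactly `(divergence.T - sum_a_exp_dissim).T + sum_b_exp_dissim.T`.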
@@ -257,27 +240,26 @@ def to_dense_array(X):
return torch.Tensor(np_array).double()


-def filter_for_common_genes(slices: list[AnnData]) -> None:
-    """
-    Filters a list of AnnData objects to retain only the common genes across
-    all slices.
-
-    Parameters
-    ----------
-    slices: List[AnnData]
-        A list of AnnData objects that represent different slices.
-    """
-    assert len(slices) > 0, "Cannot have empty list."
-
-    common_genes = slices[0].var.index
-    for s in slices:
-        common_genes = common_genes.intersection(s.var.index)
-    for i in range(len(slices)):
-        slices[i] = slices[i][:, common_genes]
-    logging.info(
-        "Filtered all slices for common genes. There are "
-        + str(len(common_genes))
-        + " common genes."
-    )
+def get_common_genes(slices: list[AnnData]) -> tuple[list[AnnData], np.ndarray]:
+    """Returns the slices restricted to their common genes, along with those genes."""
+    common_genes = slices[0].var.index
+
+    for i, slice in enumerate(slices, start=1):
+        common_genes = common_genes.intersection(slice.var.index)
+        if len(common_genes) == 0:
+            logger.error(f"Slice {i} has no common genes with rest of the slices.")
+            raise ValueError(f"Slice {i} has no common genes with rest of the slices.")
+
+    return [slice[:, common_genes] for slice in slices], common_genes
+
+
+def compute_slice_weights(slice_weights, pis, slices, device):
+    return sum(
+        [
+            slice_weights[i]
+            * torch.matmul(pis[i], to_dense_array(slices[i].X).to(device))
+            for i in range(len(slices))
+        ]
+    )
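Reviewer note: `compute_slice_weights` factors out the weighted pullback of expression onto the center slice, $\sum_i w_i \, \pi_i X_i$. A self-contained toy check of the arithmetic (shapes invented for illustration; real callers pass AnnData slices and learned alignment matrices):

```python
import torch

pis = [torch.rand(5, 4).double() for _ in range(3)]      # center spots x slice spots
exprs = [torch.rand(4, 10).double() for _ in range(3)]   # slice spots x genes
weights = [1 / 3] * 3

# Mirrors the helper's sum(slice_weights[i] * pis[i] @ X_i): one (5, 10) matrix.
blended = sum(w * (pi @ x) for w, pi, x in zip(weights, pis, exprs))
assert blended.shape == (5, 10)
```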


@@ -333,41 +315,6 @@ def match_spots_using_spatial_heuristic(
return pi


-def kl_divergence_backend(a_exp_dissim, b_exp_dissim):
-    """
-    Calculates the Kullback-Leibler divergence between two distributions.
-
-    Parameters
-    ----------
-    a_exp_dissim : torch.Tensor
-        A tensor representing the first probability distribution.
-
-    b_exp_dissim : torch.Tensor
-        A tensor representing the second probability distribution.
-
-    Returns
-    -------
-    divergence : np.ndarray
-        A tensor containing the Kullback-Leibler divergence for each sample.
-    """
-    assert (
-        a_exp_dissim.shape[1] == b_exp_dissim.shape[1]
-    ), "X and Y do not have the same number of features."
-
-    nx = ot.backend.get_backend(a_exp_dissim, b_exp_dissim)
-
-    a_exp_dissim = a_exp_dissim / nx.sum(a_exp_dissim, axis=1, keepdims=True)
-    b_exp_dissim = b_exp_dissim / nx.sum(b_exp_dissim, axis=1, keepdims=True)
-    a_log_exp_dissim = nx.log(a_exp_dissim)
-    b_log_exp_dissim = nx.log(b_exp_dissim)
-    a_weighted_dissim_sum = nx.einsum("ij,ij->i", a_exp_dissim, a_log_exp_dissim)
-    a_weighted_dissim_sum = nx.reshape(
-        a_weighted_dissim_sum, (1, a_weighted_dissim_sum.shape[0])
-    )
-    divergence = a_weighted_dissim_sum.T - nx.dot(a_exp_dissim, b_log_exp_dissim.T)
-    return nx.to_numpy(divergence)


def dissimilarity_metric(which, a_slice, b_slice, a_exp_dissim, b_exp_dissim, **kwargs):
"""
Computes a dissimilarity matrix between two distributions using a specified
@@ -408,7 +355,9 @@
case "gkl":
a_exp_dissim = a_exp_dissim + 0.01
b_exp_dissim = b_exp_dissim + 0.01
-            exp_dissim_matrix = generalized_kl_divergence(a_exp_dissim, b_exp_dissim)
+            exp_dissim_matrix = kl_divergence(
+                a_exp_dissim, b_exp_dissim, is_generalized=True
+            )
exp_dissim_matrix /= exp_dissim_matrix[exp_dissim_matrix > 0].max()
exp_dissim_matrix *= 10
return exp_dissim_matrix
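Reviewer note: a quick self-contained sanity check of the new flag behind the `gkl` branch (using `kl_divergence` directly, since `dissimilarity_metric` needs AnnData slices):

```python
import torch
from paste3.helper import kl_divergence

a = torch.rand(5, 8).double() + 0.01  # strictly positive, as the "gkl" branch ensures
b = torch.rand(7, 8).double() + 0.01

plain = kl_divergence(a, b)                     # rows normalized internally
gen = kl_divergence(a, b, is_generalized=True)  # unnormalized, keeps mass terms
assert plain.shape == gen.shape == (5, 7)
```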
1 change: 0 additions & 1 deletion src/paste3/model_selection.py
@@ -216,7 +216,6 @@ def select_overlap_fraction(sliceA, sliceB, alpha=0.1, show_plot=True, numIterma
exp_dissim_matrix=M,
alpha=alpha,
norm=True,
-return_obj=True,
numItermax=numItermax,
maxIter=numItermax,
)