Skip to content

Commit

Permalink
make palantir results fully reproducible
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Jun 20, 2023
1 parent 94c4389 commit f4970f9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
19 changes: 13 additions & 6 deletions src/palantir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pandas as pd
import networkx as nx
import time
import random
import copy

from sklearn.metrics import pairwise_distances
Expand Down Expand Up @@ -45,6 +44,7 @@ def run_palantir(
entropy_key: str = "palantir_entropy",
fate_prob_key: str = "palantir_fate_probabilities",
waypoints_key: str = "palantir_waypoints",
seed: int = 20,
) -> Optional[PResults]:
"""
Executes the Palantir algorithm to derive pseudotemporal ordering of cells, their fate probabilities, and
Expand Down Expand Up @@ -81,6 +81,8 @@ def run_palantir(
Column names of the probability matrix are stored in the AnnData's uns[fate_prob_key + "_columns"].
waypoints_key : str, optional
Key to store the waypoints in uns of the AnnData object. Default is 'palantir_waypoints'.
seed : int, optional
The seed for the random number generator used in waypoint sampling. Default is 20.
Returns
-------
Expand Down Expand Up @@ -126,7 +128,7 @@ def run_palantir(

# Append start cell
if isinstance(num_waypoints, int):
waypoints = _max_min_sampling(data_df, num_waypoints)
waypoints = _max_min_sampling(data_df, num_waypoints, seed)
else:
waypoints = num_waypoints
waypoints = waypoints.union(dm_boundaries)
Expand Down Expand Up @@ -175,18 +177,20 @@ def run_palantir(
return pr_res


def _max_min_sampling(data, num_waypoints):
def _max_min_sampling(data, num_waypoints, seed=None):
"""Function for max min sampling of waypoints
:param data: Data matrix along which to sample the waypoints,
usually diffusion components
:param num_waypoints: Number of waypoints to sample
:param num_jobs: Number of jobs for parallel processing
:param seed: Random number generator seed to find initial guess.
:return: pandas Series representing the sampled waypoints
"""

waypoint_set = list()
no_iterations = int((num_waypoints) / data.shape[1])
if seed is not None:
np.random.seed(seed)

# Sample along each component
N = data.shape[0]
Expand All @@ -195,7 +199,9 @@ def _max_min_sampling(data, num_waypoints):
vec = np.ravel(data[ind])

# Random initialization
iter_set = random.sample(range(N), 1)
iter_set = [
np.random.randint(N),
]

# Distances along the component
dists = np.zeros([N, no_iterations])
Expand Down Expand Up @@ -312,6 +318,7 @@ def identify_terminal_states(
num_waypoints=1200,
n_jobs=-1,
max_iterations=25,
seed=20,
):

# Scale components
Expand All @@ -330,7 +337,7 @@ def identify_terminal_states(

# Sample waypoints
# Append start cell
waypoints = _max_min_sampling(data, num_waypoints)
waypoints = _max_min_sampling(data, num_waypoints, seed)
waypoints = waypoints.union(dm_boundaries)
waypoints = pd.Index(waypoints.difference([start_cell]).unique())

Expand Down
2 changes: 1 addition & 1 deletion src/palantir/presults.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def compute_gene_trends(
pseudo_time = ad.obs[pseudo_time_key].values
masks = ad.obsm[masks_key]
branches = ad.uns[masks_key + "_columns"]

if lineages is not None:
for lin in lineages:
if lin not in branches:
Expand Down

0 comments on commit f4970f9

Please sign in to comment.