Skip to content

Commit cf8829e

Browse files
committed
adding time slicing to ts.samples()
1 parent 936bf86 commit cf8829e

File tree

3 files changed

+202
-13
lines changed

3 files changed

+202
-13
lines changed

python/CHANGELOG.rst

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
- Add `__setitem__` to all tables allowing single rows to be updated. For example
2323
`tables.nodes[0] = tables.nodes[0].replace(flags=tskit.NODE_IS_SAMPLE)`
2424
(:user:`jeromekelleher`, :user:`benjeffery`, :issue:`1545`, :pr:`1600`).
25+
- Added a new parameter ``time`` to ``TreeSequence.samples()`` allowing selection of
26+
samples at a specific time point or time interval.
27+
(:user:`mufernando`, :user:`petrelharp`, :issue:`1692`, :pr:`1700`)
2528

2629
--------------------
2730
[0.3.7] - 2021-07-08

python/tests/test_highlevel.py

+166-3
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,23 @@ def get_mrca(pi, x, y):
346346
return mrca
347347

348348

349+
def get_samples(ts, time=None, population=None):
    """
    Simple per-node reference implementation used to check
    ``TreeSequence.samples()``.

    :param ts: Object whose ``nodes()`` method yields nodes exposing
        ``id``, ``time``, ``population`` and ``is_sample()``.
    :param time: Either None (no filtering by time), a single numeric value
        (keep nodes whose time is approximately equal to it, as judged by
        ``np.isclose``), or a pair ``(min_time, max_time)`` (keep nodes with
        ``min_time <= node.time < max_time``). Scalars, lists, tuples and
        numpy arrays are all accepted.
    :param population: If not None, keep only nodes in this population.
    :return: The IDs of the matching sample nodes, in iteration order.
    :rtype: numpy.ndarray (dtype=np.int32)
    :raises ValueError: If ``time`` is neither a single value nor a pair.
    """
    if time is not None:
        # ndmin=1 turns scalars into 1d arrays, so numpy scalars and arrays
        # are treated the same way as Python numbers, lists and tuples
        # rather than being silently ignored.
        time = np.array(time, ndmin=1, dtype=float)
        if time.shape not in ((1,), (2,)):
            raise ValueError(
                "time must be either a single value or a pair of values "
                "(min_time, max_time)."
            )
    samples = []
    for node in ts.nodes():
        keep = bool(node.is_sample())
        if time is not None:
            if time.shape == (1,):
                keep &= bool(np.isclose(node.time, time[0]))
            else:
                keep &= bool(time[0] <= node.time < time[1])
        if population is not None:
            keep &= node.population == population
        if keep:
            samples.append(node.id)
    # Fix the dtype so that an empty result is still an integer array.
    return np.array(samples, dtype=np.int32)
364+
365+
349366
class TestMRCACalculator:
350367
"""
351368
Class to test the Schieber-Vishkin algorithm.
@@ -509,11 +526,14 @@ class TestNumpySamples:
509526
various methods.
510527
"""
511528

512-
def get_tree_sequence(self, num_demes=4):
513-
n = 40
529+
def get_tree_sequence(self, num_demes=4, times=None, n=40):
530+
if times is None:
531+
times = [0]
514532
return msprime.simulate(
515533
samples=[
516-
msprime.Sample(time=0, population=j % num_demes) for j in range(n)
534+
msprime.Sample(time=t, population=j % num_demes)
535+
for j in range(n)
536+
for t in times
517537
],
518538
population_configurations=[
519539
msprime.PopulationConfiguration() for _ in range(num_demes)
@@ -541,6 +561,149 @@ def test_samples(self):
541561
]
542562
assert total == ts.num_samples
543563

564+
def test_samples_time(self):
    """
    Check that selecting samples at a single time point agrees with the
    ``get_samples`` reference, both with and without a population filter.
    """
    sample_times = [0, 0.1, 1 / 3, 1 / 4, 5 / 7]
    ts = self.get_tree_sequence(num_demes=2, n=20, times=sample_times)
    for t in sample_times:
        expected = get_samples(ts, time=t)
        assert np.array_equal(expected, ts.samples(time=t))
        for pop in (None, 0):
            expected = get_samples(ts, time=t, population=pop)
            observed = ts.samples(time=t, population=pop)
            assert np.array_equal(expected, observed)
574+
575+
def test_samples_time_interval(self):
    """
    Check that selecting samples in a half-open time interval
    ``[min_time, max_time)`` agrees with the ``get_samples`` reference,
    both with and without a population filter.

    Besides random times inside each interval, the simulated samples
    include times below the interval, above it, and exactly on both
    boundaries, so that the filter actually has something to exclude.
    """
    rng = np.random.default_rng(seed=931)
    time_intervals = [
        [0, 0.1],
        (0, 1 / 3),
        np.array([1 / 4, 2 / 3]),
        (0.345, 5 / 7),
        (-1, 1),
    ]
    for time_interval in time_intervals:
        min_time, max_time = time_interval
        width = max_time - min_time
        inside = rng.uniform(low=min_time, high=max_time, size=20)
        # If every sample lay inside the interval the comparison below would
        # pass even when the time filter was ignored, so add samples that
        # must be excluded (below, at max_time, above) and one on the
        # inclusive lower boundary.
        edge_cases = [min_time - width, min_time, max_time, max_time + width]
        times = np.concatenate([inside, edge_cases])
        ts = self.get_tree_sequence(num_demes=2, n=1, times=times)
        # Hand the reference a plain tuple so the comparison does not depend
        # on which container types it understands; ts.samples() still
        # receives the original list/tuple/array.
        reference_interval = tuple(time_interval)
        for population in (None, 0):
            assert np.array_equal(
                get_samples(ts, time=reference_interval, population=population),
                ts.samples(time=time_interval, population=population),
            )
596+
597+
def test_samples_example(self):
    """
    Check ``samples()`` against hand-computed answers on a small table
    collection containing nine sample nodes followed by nine non-sample
    nodes with the same times and populations.
    """
    node_times = [
        np.array(0),
        0,
        np.array([1]),
        1,
        1,
        3,
        3.00001,
        3.0 - 0.0001,
        1 / 3,
    ]
    node_pops = [1, 3, 1, 2, 1, 1, 1, 3, 1]
    tables = tskit.TableCollection(sequence_length=10)
    for _ in range(max(node_pops) + 1):
        tables.populations.add_row()
    # First pass adds the samples, second pass adds not-samples also
    for node_flags in (tskit.NODE_IS_SAMPLE, 0):
        for t, p in zip(node_times, node_pops):
            tables.nodes.add_row(flags=node_flags, time=t, population=p)
    ts = tables.tree_sequence()

    every_sample = np.arange(len(node_times))
    # (keyword arguments to samples(), expected sample IDs)
    cases = [
        ({}, every_sample),
        ({"time": [0, np.inf]}, every_sample),
        ({"time": 0}, [0, 1]),
        # default tolerance is 1e-5
        ({"time": 0.3333333}, [8]),
        ({"time": 3}, [5, 6]),
        ({"time": 1}, [2, 3, 4]),
        ({"time": 1, "population": 2}, [3]),
        ({"population": 0}, []),
        ({"population": 1}, [0, 2, 4, 5, 6, 8]),
        ({"population": 2}, [3]),
        ({"time": [0, 3]}, [0, 1, 2, 3, 4, 7, 8]),
        # note tuple instead of array
        ({"time": (1, 3)}, [2, 3, 4, 7]),
        ({"time": [0, 3], "population": 1}, [0, 2, 4, 8]),
        ({"time": [0.333333, 3]}, [2, 3, 4, 7, 8]),
        ({"time": [100, np.inf]}, []),
        ({"time": -1}, []),
        ({"time": [-100, 100]}, every_sample),
        ({"time": [-100, -1]}, []),
    ]
    for kwargs, expected in cases:
        assert np.array_equal(ts.samples(**kwargs), expected)
691+
692+
def test_samples_time_errors(self):
    """
    Check that ``samples()`` raises ValueError for ``time`` arguments that
    are not a single value or a valid ``(min_time, max_time)`` pair.
    """
    ts = self.get_tree_sequence(4)
    invalid_times = [
        # error incorrect types
        "s",
        [],
        np.array([1, 2, 3]),
        (1, 2, 3),
        # error using min and max switched
        (2.4, 1),
    ]
    for invalid_time in invalid_times:
        with pytest.raises(ValueError):
            ts.samples(time=invalid_time)
706+
544707
def test_genotype_matrix_indexing(self):
545708
num_demes = 4
546709
ts = self.get_tree_sequence(num_demes)

python/tskit/trees.py

+33-10
Original file line numberDiff line numberDiff line change
@@ -4800,15 +4800,21 @@ def get_samples(self, population_id=None):
48004800
# Deprecated alias for samples()
48014801
return self.samples(population_id)
48024802

4803-
def samples(self, population=None, population_id=None):
4804-
"""
4805-
Returns an array of the sample node IDs in this tree sequence. If the
4806-
``population`` parameter is specified, only return sample IDs from that
4807-
population.
4808-
4809-
:param int population: The population of interest. If None,
4810-
return all samples.
4803+
def samples(self, population=None, population_id=None, time=None):
4804+
"""
4805+
Returns an array of the sample node IDs in this tree sequence. If
4806+
`population` is specified, only return sample IDs from that population.
4807+
It is also possible to restrict samples by time using the parameter
4808+
`time`. If `time` is a numeric value, only return sample IDs whose node
4809+
time is approximately equal to the specified time. If `time` is a pair
4810+
of values of the form `(min_time, max_time)`, only return sample IDs
4811+
whose node time `t` is in this interval such that `min_time <= t < max_time`.
4812+
4813+
:param int population: The population of interest. If None, do not
4814+
filter samples by population.
48114815
:param int population_id: Deprecated alias for ``population``.
4816+
:param float,tuple time: The time or time interval of interest. If
4817+
None, do not filter samples by time.
48124818
:return: A numpy array of the node IDs for the samples of interest,
48134819
listed in numerical order.
48144820
:rtype: numpy.ndarray (dtype=np.int32)
@@ -4820,10 +4826,27 @@ def samples(self, population=None, population_id=None):
48204826
if population_id is not None:
48214827
population = population_id
48224828
samples = self._ll_tree_sequence.get_samples()
4829+
keep = np.full(shape=samples.shape, fill_value=True)
48234830
if population is not None:
48244831
sample_population = self.tables.nodes.population[samples]
4825-
samples = samples[sample_population == population]
4826-
return samples
4832+
keep = np.logical_and(keep, sample_population == population)
4833+
if time is not None:
4834+
# ndmin is set so that scalars are converted into 1d arrays
4835+
time = np.array(time, ndmin=1, dtype=float)
4836+
sample_times = self.tables.nodes.time[samples]
4837+
if time.shape == (1,):
4838+
keep = np.logical_and(keep, np.isclose(sample_times, time))
4839+
elif time.shape == (2,):
4840+
if time[1] <= time[0]:
4841+
raise ValueError("time_interval max is less than or equal to min.")
4842+
keep = np.logical_and(keep, sample_times >= time[0])
4843+
keep = np.logical_and(keep, sample_times < time[1])
4844+
else:
4845+
raise ValueError(
4846+
"time must be either a single value or a pair of values "
4847+
"(min_time, max_time)."
4848+
)
4849+
return samples[keep]
48274850

48284851
def write_fasta(self, output, sequence_ids=None, wrap_width=60):
48294852
""

0 commit comments

Comments
 (0)