@@ -346,6 +346,23 @@ def get_mrca(pi, x, y):
346
346
return mrca
347
347
348
348
349
+ def get_samples (ts , time = None , population = None ):
350
+ samples = []
351
+ for node in ts .nodes ():
352
+ keep = bool (node .is_sample ())
353
+ if time is not None :
354
+ if isinstance (time , (int , float )):
355
+ keep &= np .isclose (node .time , time )
356
+ if isinstance (time , (tuple , list )):
357
+ keep &= node .time >= time [0 ]
358
+ keep &= node .time < time [1 ]
359
+ if population is not None :
360
+ keep &= node .population == population
361
+ if keep :
362
+ samples .append (node .id )
363
+ return np .array (samples )
364
+
365
+
349
366
class TestMRCACalculator :
350
367
"""
351
368
Class to test the Schieber-Vishkin algorithm.
@@ -509,11 +526,14 @@ class TestNumpySamples:
509
526
various methods.
510
527
"""
511
528
512
- def get_tree_sequence (self , num_demes = 4 ):
513
- n = 40
529
+ def get_tree_sequence (self , num_demes = 4 , times = None , n = 40 ):
530
+ if times is None :
531
+ times = [0 ]
514
532
return msprime .simulate (
515
533
samples = [
516
- msprime .Sample (time = 0 , population = j % num_demes ) for j in range (n )
534
+ msprime .Sample (time = t , population = j % num_demes )
535
+ for j in range (n )
536
+ for t in times
517
537
],
518
538
population_configurations = [
519
539
msprime .PopulationConfiguration () for _ in range (num_demes )
@@ -541,6 +561,149 @@ def test_samples(self):
541
561
]
542
562
assert total == ts .num_samples
543
563
564
+ def test_samples_time (self ):
565
+ times = [0 , 0.1 , 1 / 3 , 1 / 4 , 5 / 7 ]
566
+ ts = self .get_tree_sequence (num_demes = 2 , n = 20 , times = times )
567
+ for time in times :
568
+ assert np .array_equal (get_samples (ts , time = time ), ts .samples (time = time ))
569
+ for population in (None , 0 ):
570
+ assert np .array_equal (
571
+ get_samples (ts , time = time , population = population ),
572
+ ts .samples (time = time , population = population ),
573
+ )
574
+
575
+ def test_samples_time_interval (self ):
576
+ rng = np .random .default_rng (seed = 931 )
577
+ time_intervals = [
578
+ [0 , 0.1 ],
579
+ (0 , 1 / 3 ),
580
+ np .array ([1 / 4 , 2 / 3 ]),
581
+ (0.345 , 5 / 7 ),
582
+ (- 1 , 1 ),
583
+ ]
584
+ for time_interval in time_intervals :
585
+ times = rng .uniform (low = time_interval [0 ], high = time_interval [1 ], size = 20 )
586
+ ts = self .get_tree_sequence (num_demes = 2 , n = 1 , times = times )
587
+ assert np .array_equal (
588
+ get_samples (ts , time = time_interval ),
589
+ ts .samples (time = time_interval ),
590
+ )
591
+ for population in (None , 0 ):
592
+ assert np .array_equal (
593
+ get_samples (ts , time = time_interval , population = population ),
594
+ ts .samples (time = time_interval , population = population ),
595
+ )
596
+
597
+ def test_samples_example (self ):
598
+ tables = tskit .TableCollection (sequence_length = 10 )
599
+ time = [np .array (0 ), 0 , np .array ([1 ]), 1 , 1 , 3 , 3.00001 , 3.0 - 0.0001 , 1 / 3 ]
600
+ pops = [1 , 3 , 1 , 2 , 1 , 1 , 1 , 3 , 1 ]
601
+ for _ in range (max (pops ) + 1 ):
602
+ tables .populations .add_row ()
603
+ for t , p in zip (time , pops ):
604
+ tables .nodes .add_row (
605
+ flags = tskit .NODE_IS_SAMPLE ,
606
+ time = t ,
607
+ population = p ,
608
+ )
609
+ # add not-samples also
610
+ for t , p in zip (time , pops ):
611
+ tables .nodes .add_row (
612
+ flags = 0 ,
613
+ time = t ,
614
+ population = p ,
615
+ )
616
+ ts = tables .tree_sequence ()
617
+ assert np .array_equal (
618
+ ts .samples (),
619
+ np .arange (len (time )),
620
+ )
621
+ assert np .array_equal (
622
+ ts .samples (time = [0 , np .inf ]),
623
+ np .arange (len (time )),
624
+ )
625
+ assert np .array_equal (
626
+ ts .samples (time = 0 ),
627
+ [0 , 1 ],
628
+ )
629
+ # default tolerance is 1e-5
630
+ assert np .array_equal (
631
+ ts .samples (time = 0.3333333 ),
632
+ [8 ],
633
+ )
634
+ assert np .array_equal (
635
+ ts .samples (time = 3 ),
636
+ [5 , 6 ],
637
+ )
638
+ assert np .array_equal (
639
+ ts .samples (time = 1 ),
640
+ [2 , 3 , 4 ],
641
+ )
642
+ assert np .array_equal (
643
+ ts .samples (time = 1 , population = 2 ),
644
+ [3 ],
645
+ )
646
+ assert np .array_equal (
647
+ ts .samples (population = 0 ),
648
+ [],
649
+ )
650
+ assert np .array_equal (
651
+ ts .samples (population = 1 ),
652
+ [0 , 2 , 4 , 5 , 6 , 8 ],
653
+ )
654
+ assert np .array_equal (
655
+ ts .samples (population = 2 ),
656
+ [3 ],
657
+ )
658
+ assert np .array_equal (
659
+ ts .samples (time = [0 , 3 ]),
660
+ [0 , 1 , 2 , 3 , 4 , 7 , 8 ],
661
+ )
662
+ # note tuple instead of array
663
+ assert np .array_equal (
664
+ ts .samples (time = (1 , 3 )),
665
+ [2 , 3 , 4 , 7 ],
666
+ )
667
+ assert np .array_equal (
668
+ ts .samples (time = [0 , 3 ], population = 1 ),
669
+ [0 , 2 , 4 , 8 ],
670
+ )
671
+ assert np .array_equal (
672
+ ts .samples (time = [0.333333 , 3 ]),
673
+ [2 , 3 , 4 , 7 , 8 ],
674
+ )
675
+ assert np .array_equal (
676
+ ts .samples (time = [100 , np .inf ]),
677
+ [],
678
+ )
679
+ assert np .array_equal (
680
+ ts .samples (time = - 1 ),
681
+ [],
682
+ )
683
+ assert np .array_equal (
684
+ ts .samples (time = [- 100 , 100 ]),
685
+ np .arange (len (time )),
686
+ )
687
+ assert np .array_equal (
688
+ ts .samples (time = [- 100 , - 1 ]),
689
+ [],
690
+ )
691
+
692
+ def test_samples_time_errors (self ):
693
+ ts = self .get_tree_sequence (4 )
694
+ # error incorrect types
695
+ with pytest .raises (ValueError ):
696
+ ts .samples (time = "s" )
697
+ with pytest .raises (ValueError ):
698
+ ts .samples (time = [])
699
+ with pytest .raises (ValueError ):
700
+ ts .samples (time = np .array ([1 , 2 , 3 ]))
701
+ with pytest .raises (ValueError ):
702
+ ts .samples (time = (1 , 2 , 3 ))
703
+ # error using min and max switched
704
+ with pytest .raises (ValueError ):
705
+ ts .samples (time = (2.4 , 1 ))
706
+
544
707
def test_genotype_matrix_indexing (self ):
545
708
num_demes = 4
546
709
ts = self .get_tree_sequence (num_demes )
0 commit comments