Skip to content

Commit 02d4771

Browse files
jeromekelleher
authored and
mergify-bot
committed
Implement total_branch_length in C
Closes #1794. Benchmarks: Before: big_tree 0.26075993799986463 many_trees 0.002319710470019345 After: big_tree 0.004731306999929075 many_trees 2.1276080024108523e-05 So a speedup of 2-3 orders of magnitude (probably more compared to the previous release, since this is using faster node iteration)
1 parent 2e7428b commit 02d4771

File tree

6 files changed

+125
-1
lines changed

6 files changed

+125
-1
lines changed

c/tests/test_trees.c

+39
Original file line numberDiff line numberDiff line change
@@ -4329,6 +4329,44 @@ test_single_tree_is_descendant(void)
43294329
tsk_treeseq_free(&ts);
43304330
}
43314331

4332+
/* Exercises tsk_tree_get_total_branch_length on the first tree of the
 * single-tree example (single_tree_ex_nodes / single_tree_ex_edges):
 * whole-tree queries, subtree queries and rejected node ids. */
static void
4333+
test_single_tree_total_branch_length(void)
4334+
{
4335+
int ret;
4336+
tsk_treeseq_t ts;
4337+
tsk_tree_t tree;
4338+
double length;
4339+
4340+
tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL,
4341+
NULL, NULL, NULL, 0);
4342+
ret = tsk_tree_init(&tree, &ts, 0);
4343+
CU_ASSERT_EQUAL_FATAL(ret, 0);
4344+
/* NOTE(review): 1 is presumably the "tree positioned" success code
 * (TSK_TREE_OK) -- confirm against the library headers. */
ret = tsk_tree_first(&tree);
4345+
CU_ASSERT_EQUAL_FATAL(ret, 1);
4346+
4347+
/* Whole-tree queries: TSK_NULL, node 7 and the virtual root are all
 * expected to report the same total of 9. */
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, TSK_NULL, &length), 0);
4348+
CU_ASSERT_EQUAL_FATAL(length, 9);
4349+
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 7, &length), 0);
4350+
CU_ASSERT_EQUAL_FATAL(length, 9);
4351+
CU_ASSERT_EQUAL_FATAL(
4352+
tsk_tree_get_total_branch_length(&tree, tree.virtual_root, &length), 0);
4353+
CU_ASSERT_EQUAL_FATAL(length, 9);
4354+
/* Subtree queries: expected totals are 2 for node 4, 0 for node 0
 * (presumably a leaf in this example) and 4 for node 5. */
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 4, &length), 0);
4355+
CU_ASSERT_EQUAL_FATAL(length, 2);
4356+
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 0, &length), 0);
4357+
CU_ASSERT_EQUAL_FATAL(length, 0);
4358+
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 5, &length), 0);
4359+
CU_ASSERT_EQUAL_FATAL(length, 4);
4360+
4361+
/* Node ids -2 and 8 must be rejected with TSK_ERR_NODE_OUT_OF_BOUNDS. */
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, -2, &length),
4362+
TSK_ERR_NODE_OUT_OF_BOUNDS);
4363+
CU_ASSERT_EQUAL_FATAL(
4364+
tsk_tree_get_total_branch_length(&tree, 8, &length), TSK_ERR_NODE_OUT_OF_BOUNDS);
4365+
4366+
tsk_tree_free(&tree);
4367+
tsk_treeseq_free(&ts);
4368+
}
4369+
43324370
static void
43334371
test_single_tree_map_mutations(void)
43344372
{
@@ -6605,6 +6643,7 @@ main(int argc, char **argv)
66056643
{ "test_single_tree_compute_mutation_times",
66066644
test_single_tree_compute_mutation_times },
66076645
{ "test_single_tree_is_descendant", test_single_tree_is_descendant },
6646+
{ "test_single_tree_total_branch_length", test_single_tree_total_branch_length },
66086647
{ "test_single_tree_map_mutations", test_single_tree_map_mutations },
66096648
{ "test_single_tree_map_mutations_internal_samples",
66106649
test_single_tree_map_mutations_internal_samples },

c/tskit/trees.c

+34
Original file line numberDiff line numberDiff line change
@@ -3631,6 +3631,40 @@ tsk_tree_get_time(const tsk_tree_t *self, tsk_id_t u, double *t)
36313631
return ret;
36323632
}
36333633

3634+
/* Sums the branch lengths found by a preorder traversal starting at `node`
 * and stores the result in *ret_tbl.
 *
 * A scratch array of tsk_tree_get_size_bound(self) node ids is allocated
 * for the traversal and released on every exit path. Returns 0 on success,
 * TSK_ERR_NO_MEMORY if the scratch array cannot be allocated, or whatever
 * error tsk_tree_preorder reports; *ret_tbl is only written on success. */
int
3635+
tsk_tree_get_total_branch_length(const tsk_tree_t *self, tsk_id_t node, double *ret_tbl)
3636+
{
3637+
int ret = 0;
3638+
tsk_size_t j, num_nodes;
3639+
tsk_id_t u, v;
3640+
const tsk_id_t *restrict parent = self->parent;
3641+
const double *restrict time = self->tree_sequence->tables->nodes.time;
3642+
tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes));
3643+
double sum = 0;
3644+
3645+
if (nodes == NULL) {
3646+
ret = TSK_ERR_NO_MEMORY;
3647+
goto out;
3648+
}
3649+
/* Fills nodes[0..num_nodes) with the traversal order starting at `node`. */
ret = tsk_tree_preorder(self, node, nodes, &num_nodes);
3650+
if (ret != 0) {
3651+
goto out;
3652+
}
3653+
/* We always skip the first node because we don't return the branch length
3654+
* over the input node. */
3655+
for (j = 1; j < num_nodes; j++) {
3656+
u = nodes[j];
3657+
v = parent[u];
3658+
/* Nodes without a parent contribute no branch; presumably these are
 * the additional roots visited when the traversal starts from
 * TSK_NULL or the virtual root. */
if (v != TSK_NULL) {
3659+
sum += time[v] - time[u];
3660+
}
3661+
}
3662+
*ret_tbl = sum;
3663+
out:
3664+
tsk_safe_free(nodes);
3665+
return ret;
3666+
}
3667+
36343668
int TSK_WARN_UNUSED
36353669
tsk_tree_get_sites(
36363670
const tsk_tree_t *self, const tsk_site_t **sites, tsk_size_t *sites_length)

c/tskit/trees.h

+25
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,31 @@ be greater than or equal to ``num_nodes``.
436436
*/
437437
tsk_size_t tsk_tree_get_size_bound(const tsk_tree_t *self);
438438

439+
/**
440+
@brief Returns the sum of the lengths of all branches reachable from
441+
the specified node, or from all roots if node=TSK_NULL.
442+
443+
@rst
444+
Return the total branch length in a particular subtree or of the
445+
entire tree. If the specified node is TSK_NULL (or the virtual
446+
root) the sum of the lengths of all branches reachable from roots
447+
is returned. Branch length is defined as the difference between the time
448+
of a node's parent and that of the node. The branch length of a root is zero.
449+
450+
Note that the length of the branch above the specified node itself is
451+
*not* included, so that, e.g., the total branch length of a
452+
leaf node is zero.
453+
@endrst
454+
455+
@param self A pointer to a tsk_tree_t object.
456+
@param node The node at the top of the subtree to measure, or TSK_NULL to return the
457+
total branch length of the tree.
458+
@param ret_tbl A pointer to a double in which to store the total branch length.
459+
@return 0 on success or a negative value on failure.
460+
*/
461+
int tsk_tree_get_total_branch_length(
462+
const tsk_tree_t *self, tsk_id_t node, double *ret_tbl);
463+
439464
/** @} */
440465

441466
int tsk_tree_set_root_threshold(tsk_tree_t *self, tsk_size_t root_threshold);

python/CHANGELOG.rst

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
Roughly a 10X performance increase for "preorder", "postorder", "timeasc"
6161
and "timedesc" (:user:`jeromekelleher`, :pr:`1704`).
6262

63+
- Substantial performance improvement for ``Tree.total_branch_length``
64+
(:user:`jeromekelleher`, :issue:`1794`, :pr:`1799`).
6365

6466
--------------------
6567
[0.3.7] - 2021-07-08

python/_tskitmodule.c

+24
Original file line numberDiff line numberDiff line change
@@ -9433,6 +9433,26 @@ Tree_get_num_edges(Tree *self)
94339433
return ret;
94349434
}
94359435

9436+
/* Python-facing wrapper: returns the total branch length of the whole tree
 * as a Python float.
 *
 * Returns NULL if Tree_check_state fails or the library call reports an
 * error (presumably with a Python exception set by Tree_check_state or
 * handle_library_error -- confirm against those helpers). */
static PyObject *
9437+
Tree_get_total_branch_length(Tree *self)
9438+
{
9439+
PyObject *ret = NULL;
9440+
double length;
9441+
int err;
9442+
9443+
if (Tree_check_state(self) != 0) {
9444+
goto out;
9445+
}
9446+
/* TSK_NULL requests the total over the entire tree rather than a subtree. */
err = tsk_tree_get_total_branch_length(self->tree, TSK_NULL, &length);
9447+
if (err != 0) {
9448+
handle_library_error(err);
9449+
goto out;
9450+
}
9451+
/* The "d" format builds a Python float from the C double. */
ret = Py_BuildValue("d", length);
9452+
out:
9453+
return ret;
9454+
}
9455+
94369456
static PyObject *
94379457
Tree_get_index(Tree *self)
94389458
{
@@ -10363,6 +10383,10 @@ static PyMethodDef Tree_methods[] = {
1036310383
.ml_meth = (PyCFunction) Tree_get_num_edges,
1036410384
.ml_flags = METH_NOARGS,
1036510385
.ml_doc = "Returns the number of branches in this tree." },
10386+
{ .ml_name = "get_total_branch_length",
10387+
.ml_meth = (PyCFunction) Tree_get_total_branch_length,
10388+
.ml_flags = METH_NOARGS,
10389+
.ml_doc = "Returns the sum of the branch lengths reachable from roots" },
1036610390
{ .ml_name = "get_left",
1036710391
.ml_meth = (PyCFunction) Tree_get_left,
1036810392
.ml_flags = METH_NOARGS,

python/tskit/trees.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ def total_branch_length(self):
994994
:return: The sum of lengths of branches in this tree.
995995
:rtype: float
996996
"""
997-
return sum(self.branch_length(u) for u in self.nodes())
997+
return self._ll_tree.get_total_branch_length()
998998

999999
def get_mrca(self, u, v):
10001000
# Deprecated alias for mrca

0 commit comments

Comments
 (0)