Skip to content

Commit 0470939

Browse files
committed
Add descendant and ancestral sub-ARGs
1 parent fdd6361 commit 0470939

File tree

2 files changed

+272
-1
lines changed

2 files changed

+272
-1
lines changed

docs/numba.md

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,169 @@ np.testing.assert_array_almost_equal(numba_result, tskit_result, decimal=10)
364364
print("Results match!")
365365
```
366366

367+
### Example - ARG descendant and ancestral edges calculation
368+
369+
As we have `child_index` and `parent_index`, we can efficiently find both descendant and ancestral sub-ARGs
370+
for a given node. This first example shows how to find all edges in the ARG that descend from a given node. It returns a boolean mask indicating which edges are part of the sub-ARG rooted at the specified node:
371+
372+
```{code-cell} python
373+
@numba.njit
def descendant_edges(numba_ts, u):
    """
    Mark every edge lying in the sub-ARG below node ``u``.

    Starting from ``u`` over ``[0, sequence_length)``, child edges are
    followed downwards. Each step keeps only the part of the interval
    shared by the edge and the interval over which its parent was reached.

    :param numba_ts: Tree sequence wrapper providing ``num_edges``,
        ``sequence_length``, ``child_index()`` and the ``edges_left``,
        ``edges_right`` and ``edges_child`` arrays.
    :param u: ID of the node to start from.
    :return: Boolean array of length ``num_edges``; ``True`` marks an
        edge reached by the traversal.
    """
    in_subarg = np.zeros(numba_ts.num_edges, dtype=np.bool_)
    by_parent = numba_ts.child_index()
    lefts = numba_ts.edges_left
    rights = numba_ts.edges_right
    children = numba_ts.edges_child

    # Work list of (node, interval start, interval end) still to expand.
    pending = [(u, 0.0, numba_ts.sequence_length)]

    while len(pending) > 0:
        parent, lo, hi = pending.pop()

        # Edges with ``parent`` as their parent form the ID range [first, past_last).
        first, past_last = by_parent[parent]
        for edge in range(first, past_last):
            edge_lo = lefts[edge]
            edge_hi = rights[edge]

            # Ignore edges that share no sequence with the current interval.
            if not (edge_hi > lo and hi > edge_lo):
                continue

            in_subarg[edge] = True

            # Descend into the child, restricted to the shared interval.
            shared_lo = max(edge_lo, lo)
            shared_hi = min(edge_hi, hi)
            pending.append((children[edge], shared_lo, shared_hi))

    return in_subarg
408+
```
409+
410+
```{code-cell} python
411+
# Pick a node with a high ID (likely close to a root) and collect the edges below it
test_node = max(0, numba_ts.num_nodes - 5)
edge_mask = descendant_edges(numba_ts, test_node)

# Turn the boolean mask into edge IDs and report the first few plus the total
descendant_edge_ids = np.flatnonzero(edge_mask)
print(f"Edges descended from node {test_node}: {descendant_edge_ids[:10]}...")
print(f"Total descendant edges: {np.sum(edge_mask)}")
419+
```
420+
421+
In the other direction, we can similarly find the sub-ARG that is ancestral to a given node:
422+
423+
```{code-cell} python
424+
@numba.njit
def ancestral_edges(numba_ts, u):
    """
    Mark every edge lying in the sub-ARG above node ``u``.

    Starting from ``u`` over ``[0, sequence_length)``, parent edges are
    followed upwards. Each step keeps only the part of the interval
    shared by the edge and the interval over which its child was reached.

    :param numba_ts: Tree sequence wrapper providing ``num_edges``,
        ``sequence_length``, ``parent_index()`` and the ``edges_left``,
        ``edges_right`` and ``edges_parent`` arrays.
    :param u: ID of the node to start from.
    :return: Boolean array of length ``num_edges``; ``True`` marks an
        edge reached by the traversal.
    """
    in_subarg = np.zeros(numba_ts.num_edges, dtype=np.bool_)
    by_child = numba_ts.parent_index()
    lefts = numba_ts.edges_left
    rights = numba_ts.edges_right
    parents = numba_ts.edges_parent

    # Work list of (node, interval start, interval end) still to expand.
    pending = [(u, 0.0, numba_ts.sequence_length)]

    while len(pending) > 0:
        child, lo, hi = pending.pop()

        # ``index_range`` gives the slice of ``edge_index`` holding the
        # IDs of the edges that have ``child`` as their child.
        first, past_last = by_child.index_range[child]
        for slot in range(first, past_last):
            edge = by_child.edge_index[slot]
            edge_lo = lefts[edge]
            edge_hi = rights[edge]

            # Ignore edges that share no sequence with the current interval.
            if not (edge_hi > lo and hi > edge_lo):
                continue

            in_subarg[edge] = True

            # Climb to the parent, restricted to the shared interval.
            shared_lo = max(edge_lo, lo)
            shared_hi = min(edge_hi, hi)
            pending.append((parents[edge], shared_lo, shared_hi))

    return in_subarg
460+
```
461+
462+
```{code-cell} python
463+
# Pick a node with a low ID (usually a sample) and collect the edges above it
test_node = min(5, numba_ts.num_nodes - 1)
edge_mask = ancestral_edges(numba_ts, test_node)

# Turn the boolean mask into edge IDs and report the first few plus the total
ancestral_edge_ids = np.flatnonzero(edge_mask)
print(f"Edges ancestral to node {test_node}: {ancestral_edge_ids[:10]}...")
print(f"Total ancestral edges: {np.sum(edge_mask)}")
471+
```
472+
473+
```{code-cell} python
474+
:tags: [hide-cell]
475+
# Call each function once up front so that JIT compilation has already
# happened by the time the functions are timed
for _jit_fn in (descendant_edges, ancestral_edges):
    _ = _jit_fn(numba_ts, 0)
478+
```
479+
480+
Comparing performance against equivalent implementations written with the tskit Python API shows a significant speedup:
481+
482+
```{code-cell} python
483+
def descendant_edges_tskit(ts, start_node):
    """
    Reference implementation of the descendant-edge mask using tree iteration.

    For every tree, each node visited by a preorder traversal from
    ``start_node`` (other than ``start_node`` itself) has the edge
    returned by ``tree.edge`` marked.

    :param ts: Tree sequence to iterate over.
    :param start_node: ID of the node whose descendant edges are wanted.
    :return: Boolean array of length ``ts.num_edges``.
    """
    mask = np.zeros(ts.num_edges, dtype=bool)
    for tree in ts.trees():
        for node in tree.preorder(start_node):
            # Only edges below the start node are wanted, so the start
            # node's own edge is left out.
            if node == start_node:
                continue
            mask[tree.edge(node)] = True
    return mask
490+
491+
def ancestral_edges_tskit(ts, start_node):
    """
    Reference implementation of the ancestral-edge mask using tree iteration.

    In every tree, walks from ``start_node`` towards the root via
    ``tree.parent`` and marks the edge returned by ``tree.edge`` for
    each node that has a parent.

    :param ts: Tree sequence to iterate over.
    :param start_node: ID of the node whose ancestral edges are wanted.
    :return: Boolean array of length ``ts.num_edges``.
    """
    mask = np.zeros(ts.num_edges, dtype=bool)
    for tree in ts.trees():
        node = start_node
        while True:
            above = tree.parent(node)
            # A node with no parent in this tree ends the walk.
            if above == tskit.NULL:
                break
            mask[tree.edge(node)] = True
            node = above
    return mask
502+
503+
import time

# Timing uses time.perf_counter(), a monotonic high-resolution clock meant
# for measuring short durations; time.time() follows the wall clock, which
# can be adjusted while the measurement runs.

# Time the descendant traversal from the highest-numbered node
# (presumably a root in this example -- the node ordering is not checked here)
root_node = numba_ts.num_nodes - 1
t = time.perf_counter()
numba_desc = descendant_edges(numba_ts, root_node)
print(f"Numba descendant edges time: {time.perf_counter() - t:.6f} seconds")

t = time.perf_counter()
tskit_desc = descendant_edges_tskit(ts, root_node)
print(f"tskit descendant edges time: {time.perf_counter() - t:.6f} seconds")

# Time the ancestral traversal from node 0 (presumably a sample node)
sample_node = 0
t = time.perf_counter()
numba_anc = ancestral_edges(numba_ts, sample_node)
print(f"Numba ancestral edges time: {time.perf_counter() - t:.6f} seconds")

t = time.perf_counter()
tskit_anc = ancestral_edges_tskit(ts, sample_node)
print(f"tskit ancestral edges time: {time.perf_counter() - t:.6f} seconds")

# Both implementations must produce identical edge masks
np.testing.assert_array_equal(numba_desc, tskit_desc)
np.testing.assert_array_equal(numba_anc, tskit_anc)
print("Results match!")
529+
```
367530

368531
## API Reference
369532

python/tests/test_jit.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,4 +556,112 @@ def descendant_span_tree(ts, u):
556556
for u in range(ts.num_nodes):
557557
d1 = descendant_span(numba_ts, u)
558558
d2 = descendant_span_tree(ts, u)
559-
np.testing.assert_array_almost_equal(d1, d2, decimal=10)
559+
nt.assert_array_almost_equal(d1, d2, decimal=10)
560+
561+
562+
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
def test_jit_descendant_edges(ts):
    """
    Compare the numba descendant-edge traversal with a per-tree reference
    built from the tskit Python API, for every node of ``ts``.
    """
    if ts.num_nodes == 0:
        pytest.skip("Tree sequence must have at least one node")

    @numba.njit
    def descendant_edges(numba_ts, u):
        """
        Returns a boolean mask of the edges in the sub-ARG below node u.
        """
        in_subarg = np.zeros(numba_ts.num_edges, dtype=np.bool_)
        by_parent = numba_ts.child_index()
        lefts = numba_ts.edges_left
        rights = numba_ts.edges_right
        children = numba_ts.edges_child

        # Work list of (node, interval start, interval end) still to expand.
        pending = [(u, 0.0, numba_ts.sequence_length)]

        while len(pending) > 0:
            parent, lo, hi = pending.pop()

            first, past_last = by_parent[parent]
            for edge in range(first, past_last):
                edge_lo = lefts[edge]
                edge_hi = rights[edge]

                # Ignore edges that share no sequence with the interval.
                if not (edge_hi > lo and hi > edge_lo):
                    continue

                in_subarg[edge] = True
                shared_lo = max(edge_lo, lo)
                shared_hi = min(edge_hi, hi)
                pending.append((children[edge], shared_lo, shared_hi))

        return in_subarg

    def descendant_edges_tskit(ts, start_node):
        mask = np.zeros(ts.num_edges, dtype=bool)
        for tree in ts.trees():
            for node in tree.preorder(start_node):
                # Only edges *below* start_node are wanted, so the start
                # node's own edge is left out.
                if node == start_node:
                    continue
                mask[tree.edge(node)] = True
        return mask

    numba_ts = jit_numba.jitwrap(ts)
    for u in range(ts.num_nodes):
        d1 = descendant_edges(numba_ts, u)
        d2 = descendant_edges_tskit(ts, u)
        nt.assert_array_equal(d1, d2)
612+
613+
614+
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
def test_jit_ancestral_edges(ts):
    """
    Compare the numba ancestral-edge traversal with a per-tree reference
    built from the tskit Python API, for every node of ``ts``.
    """
    if ts.num_nodes == 0:
        pytest.skip("Tree sequence must have at least one node")

    @numba.njit
    def ancestral_edges(numba_ts, u):
        """
        Returns a boolean mask of the edges in the sub-ARG above node u.
        """
        in_subarg = np.zeros(numba_ts.num_edges, dtype=np.bool_)
        by_child = numba_ts.parent_index()
        lefts = numba_ts.edges_left
        rights = numba_ts.edges_right
        parents = numba_ts.edges_parent

        # Work list of (node, interval start, interval end) still to expand.
        pending = [(u, 0.0, numba_ts.sequence_length)]

        while len(pending) > 0:
            child, lo, hi = pending.pop()

            first, past_last = by_child.index_range[child]
            for slot in range(first, past_last):
                edge = by_child.edge_index[slot]
                edge_lo = lefts[edge]
                edge_hi = rights[edge]

                # Ignore edges that share no sequence with the interval.
                if not (edge_hi > lo and hi > edge_lo):
                    continue

                in_subarg[edge] = True
                shared_lo = max(edge_lo, lo)
                shared_hi = min(edge_hi, hi)
                pending.append((parents[edge], shared_lo, shared_hi))

        return in_subarg

    def ancestral_edges_tskit(ts, start_node):
        mask = np.zeros(ts.num_edges, dtype=bool)
        for tree in ts.trees():
            node = start_node
            while True:
                above = tree.parent(node)
                # A node with no parent in this tree ends the walk.
                if above == tskit.NULL:
                    break
                mask[tree.edge(node)] = True
                node = above
        return mask

    numba_ts = jit_numba.jitwrap(ts)
    for u in range(ts.num_nodes):
        a1 = ancestral_edges(numba_ts, u)
        a2 = ancestral_edges_tskit(ts, u)
        nt.assert_array_equal(a1, a2)

0 commit comments

Comments
 (0)