@@ -364,6 +364,169 @@ np.testing.assert_array_almost_equal(numba_result, tskit_result, decimal=10)
364364print("Results match!")
365365```
366366
367+ ### Example - ARG descendant and ancestral edges calculation
368+
369+ As we have ` child_index ` and ` parent_index ` , we can efficiently find both descendant and ancestral sub-ARGs
370+ for a given node. This first example shows how to find all edges in the ARG that are descendants of a given node. It returns a boolean mask indicating which edges are part of the sub-ARG rooted at the specified node:
371+
372+ ``` {code-cell} python
373+ @numba.njit
374+ def descendant_edges(numba_ts, u):
375+ """
376+ Returns a boolean mask for edges that are descendants of node u.
377+ """
378+ edge_mask = np.zeros(numba_ts.num_edges, dtype=np.bool_)
379+ child_index = numba_ts.child_index()
380+ edges_left = numba_ts.edges_left
381+ edges_right = numba_ts.edges_right
382+ edges_child = numba_ts.edges_child
383+
384+ # The stack stores (node_id, left_coord, right_coord)
385+ stack = [(u, 0.0, numba_ts.sequence_length)]
386+
387+ while len(stack) > 0:
388+ node, left, right = stack.pop()
389+
390+ # Find all edges where 'node' is the parent
391+ start, stop = child_index[node]
392+ for e in range(start, stop):
393+ e_left = edges_left[e]
394+ e_right = edges_right[e]
395+
396+ # Check for genomic interval overlap
397+ if e_right > left and right > e_left:
398+ # This edge is part of the sub-ARG
399+ edge_mask[e] = True
400+
401+ # Calculate the intersection for the next traversal step
402+ inter_left = max(e_left, left)
403+ inter_right = min(e_right, right)
404+ e_child = edges_child[e]
405+ stack.append((e_child, inter_left, inter_right))
406+
407+ return edge_mask
408+ ```
409+
410+ ``` {code-cell} python
411+ # Find descendant edges for a high-numbered node (likely near root)
412+ test_node = max(0, numba_ts.num_nodes - 5)
413+ edge_mask = descendant_edges(numba_ts, test_node)
414+
415+ # Show which edges are descendants
416+ descendant_edge_ids = np.where(edge_mask)[0]
417+ print(f"Edges descended from node {test_node}: {descendant_edge_ids[:10]}...")
418+ print(f"Total descendant edges: {np.sum(edge_mask)}")
419+ ```
420+
421+ In the other direction, we can similarly find the sub-ARG that is ancestral to a given node:
422+
423+ ``` {code-cell} python
424+ @numba.njit
425+ def ancestral_edges(numba_ts, u):
426+ """
427+ Returns a boolean mask for edges that are ancestors of node u.
428+ """
429+ edge_mask = np.zeros(numba_ts.num_edges, dtype=np.bool_)
430+ parent_index = numba_ts.parent_index()
431+ edges_left = numba_ts.edges_left
432+ edges_right = numba_ts.edges_right
433+ edges_parent = numba_ts.edges_parent
434+
435+ # The stack stores (node_id, left_coord, right_coord)
436+ stack = [(u, 0.0, numba_ts.sequence_length)]
437+
438+ while len(stack) > 0:
439+ node, left, right = stack.pop()
440+
441+ # Find all edges where 'node' is the child
442+ start, stop = parent_index.index_range[node]
443+ for i in range(start, stop):
444+ e = parent_index.edge_index[i]
445+ e_left = edges_left[e]
446+ e_right = edges_right[e]
447+
448+ # Check for genomic interval overlap
449+ if e_right > left and right > e_left:
450+ # This edge is part of the sub-ARG
451+ edge_mask[e] = True
452+
453+ # Calculate the intersection for the next traversal step
454+ inter_left = max(e_left, left)
455+ inter_right = min(e_right, right)
456+ e_parent = edges_parent[e]
457+ stack.append((e_parent, inter_left, inter_right))
458+
459+ return edge_mask
460+ ```
461+
462+ ``` {code-cell} python
463+ # Find ancestral edges for a sample node (low-numbered nodes are usually samples)
464+ test_node = min(5, numba_ts.num_nodes - 1)
465+ edge_mask = ancestral_edges(numba_ts, test_node)
466+
467+ # Show which edges are ancestors
468+ ancestral_edge_ids = np.where(edge_mask)[0]
469+ print(f"Edges ancestral to node {test_node}: {ancestral_edge_ids[:10]}...")
470+ print(f"Total ancestral edges: {np.sum(edge_mask)}")
471+ ```
472+
473+ ``` {code-cell} python
474+ :tags: [hide-cell]
475+ # Warm up the JIT for both functions
476+ _ = descendant_edges(numba_ts, 0)
477+ _ = ancestral_edges(numba_ts, 0)
478+ ```
479+
480+ Comparing performance with using the tskit Python API shows significant speedup:
481+
482+ ``` {code-cell} python
483+ def descendant_edges_tskit(ts, start_node):
484+ D = np.zeros(ts.num_edges, dtype=bool)
485+ for tree in ts.trees():
486+ for v in tree.preorder(start_node):
487+ if v != start_node:
488+ D[tree.edge(v)] = True
489+ return D
490+
491+ def ancestral_edges_tskit(ts, start_node):
492+ A = np.zeros(ts.num_edges, dtype=bool)
493+ for tree in ts.trees():
494+ curr_node = start_node
495+ parent = tree.parent(curr_node)
496+ while parent != tskit.NULL:
497+ edge_id = tree.edge(curr_node)
498+ A[edge_id] = True
499+ curr_node = parent
500+ parent = tree.parent(curr_node)
501+ return A
502+
503+ import time
504+
505+ # Test with root node for descendant edges
506+ root_node = numba_ts.num_nodes - 1
507+ t = time.time()
508+ numba_desc = descendant_edges(numba_ts, root_node)
509+ print(f"Numba descendant edges time: {time.time() - t:.6f} seconds")
510+
511+ t = time.time()
512+ tskit_desc = descendant_edges_tskit(ts, root_node)
513+ print(f"tskit descendant edges time: {time.time() - t:.6f} seconds")
514+
515+ # Test with sample node for ancestral edges
516+ sample_node = 0
517+ t = time.time()
518+ numba_anc = ancestral_edges(numba_ts, sample_node)
519+ print(f"Numba ancestral edges time: {time.time() - t:.6f} seconds")
520+
521+ t = time.time()
522+ tskit_anc = ancestral_edges_tskit(ts, sample_node)
523+ print(f"tskit ancestral edges time: {time.time() - t:.6f} seconds")
524+
525+ # Verify results match
526+ np.testing.assert_array_equal(numba_desc, tskit_desc)
527+ np.testing.assert_array_equal(numba_anc, tskit_anc)
528+ print("Results match!")
529+ ```
367530
368531## API Reference
369532
0 commit comments