@@ -364,6 +364,169 @@ np.testing.assert_array_almost_equal(numba_result, tskit_result, decimal=10)
364
364
print("Results match!")
365
365
```
366
366
367
+ ### Example - ARG descendant and ancestral edges calculation
368
+
369
+ As we have `child_index` and `parent_index`, we can efficiently find both descendant and ancestral sub-ARGs
370
+ for a given node. This first example shows how to find all edges in the ARG that are descendants of a given node. It returns a boolean mask indicating which edges are part of the sub-ARG rooted at the specified node:
371
+
372
+ ``` {code-cell} python
373
@numba.njit
def descendant_edges(numba_ts, u):
    """
    Mark every edge that lies in the sub-ARG below node ``u``.

    The traversal starts at ``u`` over the whole sequence. From each visited
    node it follows every edge that has the node as its parent and whose
    genomic span overlaps the span over which the node was reached; the
    child is then visited over the intersection of the two spans.

    :param numba_ts: Tree sequence wrapper providing ``num_edges``,
        ``sequence_length``, ``child_index()`` and the edge columns.
    :param u: ID of the node at which the traversal starts.
    :return: Boolean array of length ``num_edges``; ``True`` marks an edge
        reached by the traversal.
    """
    left_col = numba_ts.edges_left
    right_col = numba_ts.edges_right
    child_col = numba_ts.edges_child
    ranges = numba_ts.child_index()
    in_subarg = np.zeros(numba_ts.num_edges, dtype=np.bool_)

    # Pending work items: (node_id, span_left, span_right)
    todo = [(u, 0.0, numba_ts.sequence_length)]

    while len(todo) > 0:
        node, span_left, span_right = todo.pop()

        # ranges[node] is the [first, last) range of edge IDs whose parent is `node`
        first, last = ranges[node]
        for edge in range(first, last):
            lo = left_col[edge]
            hi = right_col[edge]

            # Ignore edges whose span does not overlap the current span
            if not (hi > span_left and span_right > lo):
                continue

            in_subarg[edge] = True

            # Descend into the child, restricted to the shared span
            todo.append((child_col[edge], max(lo, span_left), min(hi, span_right)))

    return in_subarg
408
+ ```
409
+
410
+ ``` {code-cell} python
411
# Pick a high-numbered node (likely near the root) and collect the edges below it
test_node = max(0, numba_ts.num_nodes - 5)
edge_mask = descendant_edges(numba_ts, test_node)

# Convert the boolean mask into edge IDs and report the first few
descendant_edge_ids = np.flatnonzero(edge_mask)
print(f"Edges descended from node {test_node}: {descendant_edge_ids[:10]}...")
print(f"Total descendant edges: {edge_mask.sum()}")
419
+ ```
420
+
421
+ In the other direction, we can similarly find the sub-ARG that is ancestral to a given node:
422
+
423
+ ``` {code-cell} python
424
@numba.njit
def ancestral_edges(numba_ts, u):
    """
    Mark every edge that lies in the sub-ARG above node ``u``.

    The traversal starts at ``u`` over the whole sequence. From each visited
    node it follows every edge that has the node as its child and whose
    genomic span overlaps the span over which the node was reached; the
    parent is then visited over the intersection of the two spans.

    :param numba_ts: Tree sequence wrapper providing ``num_edges``,
        ``sequence_length``, ``parent_index()`` and the edge columns.
    :param u: ID of the node at which the traversal starts.
    :return: Boolean array of length ``num_edges``; ``True`` marks an edge
        reached by the traversal.
    """
    left_col = numba_ts.edges_left
    right_col = numba_ts.edges_right
    parent_col = numba_ts.edges_parent
    lookup = numba_ts.parent_index()
    in_subarg = np.zeros(numba_ts.num_edges, dtype=np.bool_)

    # Pending work items: (node_id, span_left, span_right)
    todo = [(u, 0.0, numba_ts.sequence_length)]

    while len(todo) > 0:
        node, span_left, span_right = todo.pop()

        # index_range[node] is the [first, last) slice of edge_index that
        # holds the IDs of edges whose child is `node`
        first, last = lookup.index_range[node]
        for pos in range(first, last):
            edge = lookup.edge_index[pos]
            lo = left_col[edge]
            hi = right_col[edge]

            # Ignore edges whose span does not overlap the current span
            if not (hi > span_left and span_right > lo):
                continue

            in_subarg[edge] = True

            # Climb to the parent, restricted to the shared span
            todo.append((parent_col[edge], max(lo, span_left), min(hi, span_right)))

    return in_subarg
460
+ ```
461
+
462
+ ``` {code-cell} python
463
# Pick a low-numbered node (these are usually samples) and collect the edges above it
test_node = min(5, numba_ts.num_nodes - 1)
edge_mask = ancestral_edges(numba_ts, test_node)

# Convert the boolean mask into edge IDs and report the first few
ancestral_edge_ids = np.flatnonzero(edge_mask)
print(f"Edges ancestral to node {test_node}: {ancestral_edge_ids[:10]}...")
print(f"Total ancestral edges: {edge_mask.sum()}")
471
+ ```
472
+
473
+ ``` {code-cell} python
474
+ :tags: [hide-cell]
475
# Call each jitted function once so that it is compiled (JIT warm-up);
# the results are discarded.
for _warm_fn in (descendant_edges, ancestral_edges):
    _ = _warm_fn(numba_ts, 0)
478
+ ```
479
+
480
+ Comparing performance against the tskit Python API shows a significant speedup:
481
+
482
+ ``` {code-cell} python
483
def descendant_edges_tskit(ts, start_node):
    """
    Mark the edges below ``start_node`` using the tskit tree API.

    For every tree in ``ts``, each node in the preorder traversal from
    ``start_node`` (other than ``start_node`` itself) contributes the edge
    returned by ``tree.edge`` for that node.

    :param ts: Tree sequence providing ``num_edges`` and ``trees()``.
    :param start_node: ID of the node at which each traversal starts.
    :return: Boolean array of length ``ts.num_edges``; ``True`` marks an
        edge found below ``start_node`` in at least one tree.
    """
    is_descendant = np.zeros(ts.num_edges, dtype=bool)
    for tree in ts.trees():
        for v in tree.preorder(start_node):
            # The start node's own edge is not part of the result
            if v == start_node:
                continue
            is_descendant[tree.edge(v)] = True
    return is_descendant
490
+
491
def ancestral_edges_tskit(ts, start_node):
    """
    Mark the edges above ``start_node`` using the tskit tree API.

    For every tree in ``ts``, the walk climbs from ``start_node`` via
    ``tree.parent`` until it returns ``tskit.NULL``, marking the edge
    returned by ``tree.edge`` for each node that has a parent.

    :param ts: Tree sequence providing ``num_edges`` and ``trees()``.
    :param start_node: ID of the node at which each walk starts.
    :return: Boolean array of length ``ts.num_edges``; ``True`` marks an
        edge found above ``start_node`` in at least one tree.
    """
    is_ancestral = np.zeros(ts.num_edges, dtype=bool)
    for tree in ts.trees():
        node = start_node
        while True:
            above = tree.parent(node)
            # A NULL parent means there is nothing further up in this tree
            if above == tskit.NULL:
                break
            is_ancestral[tree.edge(node)] = True
            node = above
    return is_ancestral
502
+
503
import time

# Time both implementations of the descendant-edge search, starting from the
# highest-numbered node. time.perf_counter() is used because it is a
# monotonic, high-resolution clock meant for measuring short intervals,
# whereas time.time() follows the (adjustable) system wall clock.
root_node = numba_ts.num_nodes - 1
t = time.perf_counter()
numba_desc = descendant_edges(numba_ts, root_node)
print(f"Numba descendant edges time: {time.perf_counter() - t:.6f} seconds")

t = time.perf_counter()
tskit_desc = descendant_edges_tskit(ts, root_node)
print(f"tskit descendant edges time: {time.perf_counter() - t:.6f} seconds")

# Time both implementations of the ancestral-edge search, starting from node 0
sample_node = 0
t = time.perf_counter()
numba_anc = ancestral_edges(numba_ts, sample_node)
print(f"Numba ancestral edges time: {time.perf_counter() - t:.6f} seconds")

t = time.perf_counter()
tskit_anc = ancestral_edges_tskit(ts, sample_node)
print(f"tskit ancestral edges time: {time.perf_counter() - t:.6f} seconds")

# Both implementations must produce identical edge masks
np.testing.assert_array_equal(numba_desc, tskit_desc)
np.testing.assert_array_equal(numba_anc, tskit_anc)
print("Results match!")
529
+ ```
367
530
368
531
## API Reference
369
532
0 commit comments