Skip to content

Commit ce0d8ca

Browse files
committed
mask->select
1 parent 0470939 commit ce0d8ca

File tree

2 files changed

+25
-23
lines changed

2 files changed

+25
-23
lines changed

docs/numba.md

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -367,15 +367,15 @@ print("Results match!")
367367
### Example - ARG descendant and ancestral edges calculation
368368

369369
As we have `child_index` and `parent_index`, we can efficiently find both descendant and ancestral sub-ARGs
370-
for a given node. This first example shows how to find all edges in the ARG that are descendants of a given node. It returns a boolean mask indicating which edges are part of the sub-ARG rooted at the specified node:
370+
for a given node. This first example shows how to find all edges in the ARG that are descendants of a given node. It returns a boolean array indicating which edges are part of the sub-ARG rooted at the specified node:
371371

372372
```{code-cell} python
373373
@numba.njit
374374
def descendant_edges(numba_ts, u):
375375
"""
376-
Returns a boolean mask for edges that are descendants of node u.
376+
Returns a boolean array which is only True for edges that are descendants of node u.
377377
"""
378-
edge_mask = np.zeros(numba_ts.num_edges, dtype=np.bool_)
378+
edge_select = np.zeros(numba_ts.num_edges, dtype=np.bool_)
379379
child_index = numba_ts.child_index()
380380
edges_left = numba_ts.edges_left
381381
edges_right = numba_ts.edges_right
@@ -396,26 +396,26 @@ def descendant_edges(numba_ts, u):
396396
# Check for genomic interval overlap
397397
if e_right > left and right > e_left:
398398
# This edge is part of the sub-ARG
399-
edge_mask[e] = True
399+
edge_select[e] = True
400400
401401
# Calculate the intersection for the next traversal step
402402
inter_left = max(e_left, left)
403403
inter_right = min(e_right, right)
404404
e_child = edges_child[e]
405405
stack.append((e_child, inter_left, inter_right))
406406
407-
return edge_mask
407+
return edge_select
408408
```
409409

410410
```{code-cell} python
411411
# Find descendant edges for a high-numbered node (likely near root)
412412
test_node = max(0, numba_ts.num_nodes - 5)
413-
edge_mask = descendant_edges(numba_ts, test_node)
413+
edge_select = descendant_edges(numba_ts, test_node)
414414
415415
# Show which edges are descendants
416-
descendant_edge_ids = np.where(edge_mask)[0]
416+
descendant_edge_ids = np.where(edge_select)[0]
417417
print(f"Edges descended from node {test_node}: {descendant_edge_ids[:10]}...")
418-
print(f"Total descendant edges: {np.sum(edge_mask)}")
418+
print(f"Total descendant edges: {np.sum(edge_select)}")
419419
```
420420

421421
In the other direction, we can similarly find the sub-ARG that is ancestral to a given node:
@@ -424,9 +424,9 @@ In the other direction, we can similarly find the sub-ARG that is ancestral to a
424424
@numba.njit
425425
def ancestral_edges(numba_ts, u):
426426
"""
427-
Returns a boolean mask for edges that are ancestors of node u.
427+
Returns a boolean array which is only True for edges that are ancestors of node u.
428428
"""
429-
edge_mask = np.zeros(numba_ts.num_edges, dtype=np.bool_)
429+
edge_select = np.zeros(numba_ts.num_edges, dtype=np.bool_)
430430
parent_index = numba_ts.parent_index()
431431
edges_left = numba_ts.edges_left
432432
edges_right = numba_ts.edges_right
@@ -448,26 +448,26 @@ def ancestral_edges(numba_ts, u):
448448
# Check for genomic interval overlap
449449
if e_right > left and right > e_left:
450450
# This edge is part of the sub-ARG
451-
edge_mask[e] = True
451+
edge_select[e] = True
452452
453453
# Calculate the intersection for the next traversal step
454454
inter_left = max(e_left, left)
455455
inter_right = min(e_right, right)
456456
e_parent = edges_parent[e]
457457
stack.append((e_parent, inter_left, inter_right))
458458
459-
return edge_mask
459+
return edge_select
460460
```
461461

462462
```{code-cell} python
463463
# Find ancestral edges for a sample node (low-numbered nodes are usually samples)
464464
test_node = min(5, numba_ts.num_nodes - 1)
465-
edge_mask = ancestral_edges(numba_ts, test_node)
465+
edge_select = ancestral_edges(numba_ts, test_node)
466466
467467
# Show which edges are ancestors
468-
ancestral_edge_ids = np.where(edge_mask)[0]
468+
ancestral_edge_ids = np.where(edge_select)[0]
469469
print(f"Edges ancestral to node {test_node}: {ancestral_edge_ids[:10]}...")
470-
print(f"Total ancestral edges: {np.sum(edge_mask)}")
470+
print(f"Total ancestral edges: {np.sum(edge_select)}")
471471
```
472472

473473
```{code-cell} python

python/tests/test_jit.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -567,9 +567,10 @@ def test_jit_descendant_edges(ts):
567567
@numba.njit
568568
def descendant_edges(numba_ts, u):
569569
"""
570-
Returns a boolean mask for edges that are descendants of node u.
570+
Returns a boolean array which is only True for edges that
571+
are descendants of node u.
571572
"""
572-
edge_mask = np.zeros(numba_ts.num_edges, dtype=np.bool_)
573+
edge_select = np.zeros(numba_ts.num_edges, dtype=np.bool_)
573574
child_index = numba_ts.child_index()
574575
edges_left = numba_ts.edges_left
575576
edges_right = numba_ts.edges_right
@@ -587,13 +588,13 @@ def descendant_edges(numba_ts, u):
587588
e_right = edges_right[e]
588589

589590
if e_right > left and right > e_left:
590-
edge_mask[e] = True
591+
edge_select[e] = True
591592
inter_left = max(e_left, left)
592593
inter_right = min(e_right, right)
593594
e_child = edges_child[e]
594595
stack.append((e_child, inter_left, inter_right))
595596

596-
return edge_mask
597+
return edge_select
597598

598599
def descendant_edges_tskit(ts, start_node):
599600
D = np.zeros(ts.num_edges, dtype=bool)
@@ -619,9 +620,10 @@ def test_jit_ancestral_edges(ts):
619620
@numba.njit
620621
def ancestral_edges(numba_ts, u):
621622
"""
622-
Returns a boolean mask for edges that are ancestors of node u.
623+
Returns a boolean array which is only True for edges that are
624+
ancestors of node u.
623625
"""
624-
edge_mask = np.zeros(numba_ts.num_edges, dtype=np.bool_)
626+
edge_select = np.zeros(numba_ts.num_edges, dtype=np.bool_)
625627
parent_index = numba_ts.parent_index()
626628
edges_left = numba_ts.edges_left
627629
edges_right = numba_ts.edges_right
@@ -640,13 +642,13 @@ def ancestral_edges(numba_ts, u):
640642
e_right = edges_right[e]
641643

642644
if e_right > left and right > e_left:
643-
edge_mask[e] = True
645+
edge_select[e] = True
644646
inter_left = max(e_left, left)
645647
inter_right = min(e_right, right)
646648
e_parent = edges_parent[e]
647649
stack.append((e_parent, inter_left, inter_right))
648650

649-
return edge_mask
651+
return edge_select
650652

651653
def ancestral_edges_tskit(ts, start_node):
652654
A = np.zeros(ts.num_edges, dtype=bool)

0 commit comments

Comments
 (0)