Skip to content

Commit 42554da

Browse files
benjefferyjeromekelleher
authored andcommitted
Initial ARG traversal indexes for numba
1 parent 2a4ec4f commit 42554da

File tree

3 files changed

+492
-52
lines changed

3 files changed

+492
-52
lines changed

docs/numba.md

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ The numba integration provides:
4444
- **{class}`NumbaTreeSequence`**: A Numba-compatible representation of tree sequence data
4545
- **{class}`NumbaTreeIndex`**: A class for efficient tree iteration
4646
- **{class}`NumbaEdgeRange`**: Container class for edge ranges during iteration
47+
- **{class}`NumbaChildIndex`**: A class for efficiently finding child edges of nodes
48+
- **{class}`NumbaParentIndex`**: A class for efficiently finding parent edges of nodes
4749

4850
These classes are designed to work within Numba's `@njit` decorated functions,
4951
allowing you to write high-performance tree sequence analysis code.
@@ -76,7 +78,7 @@ print(type(numba_ts))
7678

7779
## Tree Iteration
7880

79-
Tree iteration can be performed using the {class}`NumbaTreeIndex` class.
81+
Tree iteration can be performed in `numba.njit` compiled functions using the {class}`NumbaTreeIndex` class.
8082
This class provides `next()` and `prev()` methods for forward and backward iteration through the trees in a tree sequence. Its `in_range` and `out_range` attributes provide the edges that must be added or removed to form the current
8183
tree from the previous tree, along with the current tree `interval` and its sites and mutations through `site_range` and `mutation_range`.
8284

@@ -134,7 +136,7 @@ print(f"Normal Time taken: {time.time() - t:.4f} seconds")
134136
assert jit_num_edges == python_num_edges, "JIT and normal results do not match!"
135137
```
136138

137-
## Example - diversity calculation
139+
### Example - diversity calculation
138140

139141
As a more interesting example we can calculate genetic diversity (also known as pi).
140142
For this example we'll be calculating based on the distance in the tree between samples.
@@ -252,7 +254,117 @@ print("Diversity (tskit):", d_tskit)
252254
print("Time taken:", time.time() - t)
253255
```
254256

257+
## ARG Traversal
255258

259+
Beyond iterating through trees, you may need to traverse the ARG vertically. The {class}`NumbaChildIndex` and {class}`NumbaParentIndex` classes provide efficient access to parent-child relationships in the edge table within `numba.njit` functions.
260+
261+
The {class}`NumbaChildIndex` allows you to efficiently find all edges where a given node is the parent. Since edges are already sorted by parent in the tskit data model, this is implemented using simple range indexing. For any node `u`, `child_range[u]` gives a tuple of the start and stop indices in the tskit edge table where node `u` is the parent.
262+
263+
The {class}`NumbaParentIndex` allows you to efficiently find all edges where a given node is the child. Since edges are not sorted by child in the edge table, this class builds a custom index that sorts edge IDs by child node (and then by left coordinate). For any node `u`, `parent_range[u]` gives a tuple of the start and stop indices in the `parent_index` array, and `parent_index[start:stop]` gives the actual tskit edge IDs.
264+
265+
Both indexes can be obtained from a `NumbaTreeSequence`:
266+
267+
```{code-cell} python
268+
# Get the indexes
269+
child_index = numba_ts.child_index()
270+
parent_index = numba_ts.parent_index()
271+
272+
# Example: find all edges where node 5 is the parent
273+
start, stop = child_index.child_range[5]
274+
print(f"Node 5 has {stop - start} child edges")
275+
276+
# Example: find all edges where node 3 is the child
277+
start, stop = parent_index.parent_range[3]
278+
print(f"Node 3 appears as child in {stop - start} edges")
279+
```
280+
281+
These indexes enable efficient algorithms that need to traverse parent-child relationships in the ARG, such as computing descendant sets, ancestral paths, or subtree properties.
282+
283+
### Example - descendant span calculation
284+
285+
Here's an example of using the ARG traversal classes to calculate the total sequence length over which each node descends from a specified node:
286+
287+
```{code-cell} python
288+
@numba.njit
289+
def descendant_span(numba_ts, u):
290+
"""
291+
Calculate the total sequence length over which each node
292+
descends from the specified node u.
293+
"""
294+
child_index = numba_ts.child_index()
295+
child_range = child_index.child_range
296+
edges_left = numba_ts.edges_left
297+
edges_right = numba_ts.edges_right
298+
edges_child = numba_ts.edges_child
299+
300+
total_descending = np.zeros(numba_ts.num_nodes)
301+
stack = [(u, 0.0, numba_ts.sequence_length)]
302+
303+
# TODO is it right that u is considered to inherit from itself
304+
# across the whole sequence?
305+
total_descending[u] = numba_ts.sequence_length
306+
307+
while len(stack) > 0:
308+
node, left, right = stack.pop()
309+
310+
# Find all child edges for this node
311+
for e in range(child_range[node, 0], child_range[node, 1]):
312+
e_left = edges_left[e]
313+
e_right = edges_right[e]
314+
315+
# Check if edge overlaps with current interval
316+
if e_right > left and right > e_left:
317+
inter_left = max(e_left, left)
318+
inter_right = min(e_right, right)
319+
e_child = edges_child[e]
320+
321+
total_descending[e_child] += inter_right - inter_left
322+
stack.append((e_child, inter_left, inter_right))
323+
324+
return total_descending
325+
```
326+
327+
```{code-cell} python
328+
:tags: [hide-cell]
329+
# Warm up the JIT
330+
result = descendant_span(numba_ts, 0)
331+
```
332+
333+
```{code-cell} python
334+
# Calculate descendant span for the root node (highest numbered node)
335+
root_node = numba_ts.num_nodes - 1
336+
result = descendant_span(numba_ts, root_node)
337+
338+
# Show nodes that have non-zero descendant span
339+
non_zero = result > 0
340+
print(f"Nodes descended from {root_node}:")
341+
print(f"Node IDs: {np.where(non_zero)[0]}")
342+
print(f"Span lengths: {result[non_zero]}")
343+
```
344+
345+
Comparing performance with using the tskit Python API:
346+
347+
```{code-cell} python
348+
def descendant_span_tskit(ts, u):
349+
"""Reference implementation using tskit trees"""
350+
total_descending = np.zeros(ts.num_nodes)
351+
for tree in ts.trees():
352+
descendants = tree.preorder(u)
353+
total_descending[descendants] += tree.span
354+
return total_descending
355+
356+
import time
357+
t = time.time()
358+
numba_result = descendant_span(numba_ts, root_node)
359+
print(f"Numba time: {time.time() - t:.6f} seconds")
360+
361+
t = time.time()
362+
tskit_result = descendant_span_tskit(ts, root_node)
363+
print(f"tskit time: {time.time() - t:.6f} seconds")
364+
365+
np.testing.assert_array_almost_equal(numba_result, tskit_result, decimal=10)
366+
print("Results match!")
367+
```
256368

257369

258370
## API Reference
@@ -270,4 +382,10 @@ print("Time taken:", time.time() - t)
270382
271383
.. autoclass:: NumbaEdgeRange
272384
:members:
385+
386+
.. autoclass:: NumbaChildIndex
387+
:members:
388+
389+
.. autoclass:: NumbaParentIndex
390+
:members:
273391
```

python/tests/test_jit.py

Lines changed: 173 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,126 @@ def test_correct_trees_backwards_and_forwards(ts):
117117
assert last_tree
118118

119119

120-
def test_using_from_jit_function():
120+
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
121+
def test_child_index_correctness(ts):
122+
numba_ts = jit_numba.jitwrap(ts)
123+
child_index = numba_ts.child_index()
124+
for node in range(ts.num_nodes):
125+
start, stop = child_index.child_range[node]
126+
127+
expected_children = []
128+
for edge_id in range(ts.num_edges):
129+
if ts.edges_parent[edge_id] == node:
130+
expected_children.append(edge_id)
131+
132+
if len(expected_children) == 0:
133+
assert start == -1 and stop == -1
134+
else:
135+
assert stop > start
136+
actual_children = list(range(start, stop))
137+
for edge_id in actual_children:
138+
assert ts.edges_parent[edge_id] == node
139+
assert actual_children == expected_children
140+
141+
142+
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
143+
def test_parent_index_correctness(ts):
144+
numba_ts = jit_numba.jitwrap(ts)
145+
parent_index = numba_ts.parent_index()
146+
for node in range(ts.num_nodes):
147+
start, stop = parent_index.parent_range[node]
148+
149+
expected_parents = []
150+
for edge_id in range(ts.num_edges):
151+
if ts.edges_child[edge_id] == node:
152+
expected_parents.append(edge_id)
153+
154+
if len(expected_parents) == 0:
155+
assert start == -1 and stop == -1
156+
else:
157+
assert stop > start
158+
actual_parent_edge_ids = []
159+
for j in range(start, stop):
160+
edge_id = parent_index.parent_index[j]
161+
actual_parent_edge_ids.append(edge_id)
162+
assert ts.edges_child[edge_id] == node
163+
assert set(actual_parent_edge_ids) == set(expected_parents)
164+
165+
166+
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
167+
def test_parent_index_tree_reconstruction(ts):
168+
numba_ts = jit_numba.jitwrap(ts)
169+
parent_index = numba_ts.parent_index()
170+
171+
# Test tree reconstruction at all breakpoints
172+
for tree in ts.trees():
173+
position = tree.interval.left + 0.5 * tree.span
174+
reconstructed_parent = np.full(ts.num_nodes, -1, dtype=np.int32)
175+
for node in range(ts.num_nodes):
176+
start, stop = parent_index.parent_range[node]
177+
if start != -1:
178+
for j in range(start, stop):
179+
edge_id = parent_index.parent_index[j]
180+
if ts.edges_left[edge_id] <= position < ts.edges_right[edge_id]:
181+
reconstructed_parent[node] = ts.edges_parent[edge_id]
182+
break
183+
expected_parent = tree.parent_array
184+
185+
# Compare parent arrays (excluding virtual root)
186+
nt.assert_array_equal(
187+
reconstructed_parent,
188+
expected_parent[:-1],
189+
)
190+
191+
192+
def test_child_parent_index_from_jit_function():
193+
ts = msprime.sim_ancestry(
194+
samples=10, sequence_length=100, recombination_rate=1, random_seed=42
195+
)
196+
197+
@numba.njit
198+
def _count_children_parents_numba(numba_ts):
199+
child_index = numba_ts.child_index()
200+
parent_index = numba_ts.parent_index()
201+
202+
total_child_edges = 0
203+
total_parent_edges = 0
204+
205+
for node in range(numba_ts.num_nodes):
206+
# Count child edges
207+
child_start, child_stop = child_index.child_range[node]
208+
if child_start != -1:
209+
total_child_edges += child_stop - child_start
210+
211+
# Count parent edges
212+
parent_start, parent_stop = parent_index.parent_range[node]
213+
if parent_start != -1:
214+
total_parent_edges += parent_stop - parent_start
215+
216+
return total_child_edges, total_parent_edges
217+
218+
def count_children_parents_python(ts):
219+
total_child_edges = 0
220+
total_parent_edges = 0
221+
222+
for node in range(ts.num_nodes):
223+
# Count child edges
224+
for edge in ts.edges():
225+
if edge.parent == node:
226+
total_child_edges += 1
227+
if edge.child == node:
228+
total_parent_edges += 1
229+
230+
return total_child_edges, total_parent_edges
231+
232+
numba_ts = jit_numba.jitwrap(ts)
233+
numba_result = _count_children_parents_numba(numba_ts)
234+
python_result = count_children_parents_python(ts)
235+
236+
assert numba_result == python_result
237+
238+
239+
def test_using_tree_index_from_jit_function():
121240
# Test we can use from a numba jitted function
122241

123242
ts = msprime.sim_ancestry(
@@ -391,3 +510,56 @@ def test_numba_tree_index_edge_cases():
391510
assert tree_index.index == 0
392511
assert not tree_index.next() # No more trees
393512
assert tree_index.index == -1
513+
514+
515+
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
516+
def test_jit_descendant_span(ts):
517+
if ts.num_nodes == 0:
518+
pytest.skip("Tree sequence must have at least one node")
519+
520+
@numba.njit
521+
def descendant_span(numba_ts, u):
522+
child_index = numba_ts.child_index()
523+
child_range = child_index.child_range
524+
edges_left = numba_ts.edges_left
525+
edges_right = numba_ts.edges_right
526+
edges_child = numba_ts.edges_child
527+
528+
total_descending = np.zeros(numba_ts.num_nodes)
529+
stack = [(u, 0.0, numba_ts.sequence_length)]
530+
531+
# TODO is it right that u is considered to inherit from itself
532+
# across the whole sequence?
533+
total_descending[u] = numba_ts.sequence_length
534+
535+
while len(stack) > 0:
536+
node, left, right = stack.pop()
537+
538+
# Find all child edges for this node
539+
for e in range(child_range[node, 0], child_range[node, 1]):
540+
e_left = edges_left[e]
541+
e_right = edges_right[e]
542+
543+
# Check if edge overlaps with current interval
544+
if e_right > left and right > e_left:
545+
inter_left = max(e_left, left)
546+
inter_right = min(e_right, right)
547+
e_child = edges_child[e]
548+
549+
total_descending[e_child] += inter_right - inter_left
550+
stack.append((e_child, inter_left, inter_right))
551+
552+
return total_descending
553+
554+
def descendant_span_tree(ts, u):
555+
total_descending = np.zeros(ts.num_nodes)
556+
for tree in ts.trees():
557+
descendants = tree.preorder(u)
558+
total_descending[descendants] += tree.span
559+
return total_descending
560+
561+
numba_ts = jit_numba.jitwrap(ts)
562+
for u in range(ts.num_nodes):
563+
d1 = descendant_span(numba_ts, u)
564+
d2 = descendant_span_tree(ts, u)
565+
np.testing.assert_array_almost_equal(d1, d2, decimal=10)

0 commit comments

Comments
 (0)