
Commit 917c2fd

benjeffery authored and jeromekelleher committed
Simplify and rename parent/child index API
1 parent 42554da commit 917c2fd

3 files changed: +126 -201 lines changed

docs/numba.md

Lines changed: 15 additions & 20 deletions
@@ -42,10 +42,9 @@ conda install numba
 The numba integration provides:

 - **{class}`NumbaTreeSequence`**: A Numba-compatible representation of tree sequence data
-- **{class}`NumbaTreeIndex`**: A class for efficient tree iteration
-- **{class}`NumbaEdgeRange`**: Container class for edge ranges during iteration
-- **{class}`NumbaChildIndex`**: A class for efficiently finding child edges of nodes
-- **{class}`NumbaParentIndex`**: A class for efficiently finding parent edges of nodes
+- **{class}`TreeIndex`**: A class for efficient tree iteration
+- **{class}`EdgeRange`**: Container class for edge ranges during iteration
+- **{class}`ParentIndex`**: Container for parent edge index information

 These classes are designed to work within Numba's `@njit` decorated functions,
 allowing you to write high-performance tree sequence analysis code.
@@ -78,11 +77,11 @@ print(type(numba_ts))

 ## Tree Iteration

-Tree iteration can be performed in `numba.njit` compiled functions using the {class}`NumbaTreeIndex` class.
+Tree iteration can be performed in `numba.njit` compiled functions using the {class}`TreeIndex` class.
 This class provides `next()` and `prev()` methods for forward and backward iteration through the trees in a tree sequence. Its `in_range` and `out_range` attributes provide the edges that must be added or removed to form the current
 tree from the previous tree, along with the current tree `interval` and its sites and mutations through `site_range` and `mutation_range`.

-A `NumbaTreeIndex` instance can be obtained from a `NumbaTreeSequence` using the `tree_index()` method. The initial state of this is of a "null" tree outside the range of the tree sequence, the first call to `next()` or `prev()` will be to the first, or last tree sequence tree respectively. After that, the `in_range` and `out_range` attributes will provide the edges that must be added or removed to form the current tree from the previous tree. For example
+A `TreeIndex` instance can be obtained from a {class}`NumbaTreeSequence` using the {meth}`~NumbaTreeSequence.tree_index` method. The initial state of this is of a "null" tree outside the range of the tree sequence, the first call to `next()` or `prev()` will be to the first, or last tree sequence tree respectively. After that, the `in_range` and `out_range` attributes will provide the edges that must be added or removed to form the current tree from the previous tree. For example
 `tree_index.in_range.order[in_range.start:in_range.stop]` will give the edge ids that are new in the current tree, and `tree_index.out_range.order[out_range.start:out_range.stop]` will give the edge ids that are no longer present in the current tree. `tree_index.site_range` and
 `tree_index.mutation_range` give the indexes into the tree sequences site and mutation arrays.

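Reviewer note on the `EdgeRange` semantics described in this hunk: the following is a minimal pure-Python sketch, assuming only what the docs text and `test_numba_edge_range` show (the fields `start`, `stop` and `order`). `EdgeRangeSketch` and `edge_ids` are hypothetical stand-ins for illustration, not the tskit classes, and the real class is meant to be used inside `numba.njit` functions.

```python
from typing import List, NamedTuple


class EdgeRangeSketch(NamedTuple):
    """Hypothetical stand-in for EdgeRange: a window into an edge ordering."""

    start: int
    stop: int
    order: List[int]


def edge_ids(edge_range: EdgeRangeSketch) -> List[int]:
    # The edge IDs covered by the range are order[start:stop].
    return edge_range.order[edge_range.start:edge_range.stop]


# Same values as the EdgeRange constructed in test_numba_edge_range.
in_range = EdgeRangeSketch(start=1, stop=3, order=[1, 3, 2, 0])
print(edge_ids(in_range))  # [3, 2]
```

The slice is half-open, so an `EdgeRange` with `start == stop` covers no edges.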
@@ -256,25 +255,25 @@ print("Time taken:", time.time() - t)

 ## ARG Traversal

-Beyond iterating through trees, you may need to traverse the ARG vertically. The {class}`NumbaChildIndex` and {class}`NumbaParentIndex` classes provide efficient access to parent-child relationships in the edge table within `numba.njit` functions.
+Beyond iterating through trees, you may need to traverse the ARG vertically. The {meth}`~NumbaTreeSequence.child_index` method and {class}`ParentIndex` class provide efficient access to parent-child relationships in the edge table within `numba.njit` functions.

-The {class}`NumbaChildIndex` allows you to efficiently find all edges where a given node is the parent. Since edges are already sorted by parent in the tskit data model, this is implemented using simple range indexing. For any node `u`, `child_range[u]` gives a tuple of the start and stop indices in the tskit edge table where node `u` is the parent.
+The {meth}`~NumbaTreeSequence.child_index` method returns an array that allows you to efficiently find all edges where a given node is the parent. Since edges are already sorted by parent in the tskit data model, this is implemented using simple range indexing. For any node `u`, the returned array `child_index[u]` gives a tuple of the start and stop indices in the tskit edge table where node `u` is the parent.

-The {class}`NumbaParentIndex` allows you to efficiently find all edges where a given node is the child. Since edges are not sorted by child in the edge table, this class builds a custom index that sorts edge IDs by child node (and then by left coordinate). For any node `u`, `parent_range[u]` gives a tuple of the start and stop indices in the `parent_index` array, and `parent_index[start:stop]` gives the actual tskit edge IDs.
+The {meth}`~NumbaTreeSequence.parent_index` method creates a {class}`ParentIndex` that allows you to efficiently find all edges where a given node is the child. Since edges are not sorted by child in the edge table, the returned class contains a custom index that sorts edge IDs by child node (and then by left coordinate). For any node `u`, `parent_index.index_range[u]` gives a tuple of the start and stop indices in the `edge_index` array, and `parent_index.edge_index[start:stop]` gives the actual tskit edge IDs.

-Both indexes can be obtained from a `NumbaTreeSequence`:
+Both can be obtained from a {class}`NumbaTreeSequence`:

 ```{code-cell} python
 # Get the indexes
 child_index = numba_ts.child_index()
 parent_index = numba_ts.parent_index()

 # Example: find all edges where node 5 is the parent
-start, stop = child_index.child_range[5]
+start, stop = child_index[5]
 print(f"Node 5 has {stop - start} child edges")

 # Example: find all edges where node 3 is the child
-start, stop = parent_index.parent_range[3]
+start, stop = parent_index.index_range[3]
 print(f"Node 3 appears as child in {stop - start} edges")
 ```

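Reviewer note on the two index layouts described in this hunk: the sketch below rebuilds them in plain Python from a toy edge table, to make the new `child_index[u]` and `ParentIndex.index_range` / `edge_index` shapes concrete. It is an illustration under stated assumptions, not the tskit implementation: `build_child_index`, `build_parent_index` and the toy data are hypothetical, and the `-1` sentinel for nodes without edges is inferred from the `start != -1` checks in the tests of this commit.

```python
from typing import List, Tuple

Range = Tuple[int, int]


def build_child_index(edges_parent: List[int], num_nodes: int) -> List[Range]:
    """(start, stop) rows into the edge table, assuming edges sorted by parent."""
    index = [(-1, -1)] * num_nodes
    for e, p in enumerate(edges_parent):
        start, _ = index[p]
        # Keep the first edge seen as start; stop is one past the latest edge.
        index[p] = (e if start == -1 else start, e + 1)
    return index


def build_parent_index(
    edges_child: List[int], edges_left: List[float], num_nodes: int
) -> Tuple[List[Range], List[int]]:
    """Return (index_range, edge_index), with edge IDs sorted by (child, left)."""
    edge_index = sorted(
        range(len(edges_child)), key=lambda e: (edges_child[e], edges_left[e])
    )
    index_range = [(-1, -1)] * num_nodes
    for j, e in enumerate(edge_index):
        c = edges_child[e]
        start, _ = index_range[c]
        index_range[c] = (j if start == -1 else start, j + 1)
    return index_range, edge_index


# Toy edge table (made-up data): 5 nodes, edges already sorted by parent.
# Node 0 has parent 3 from position 0.0 and parent 4 from position 5.0.
edges_parent = [3, 3, 4, 4, 4]
edges_child = [0, 1, 2, 3, 0]
edges_left = [0.0, 0.0, 0.0, 0.0, 5.0]

child_index = build_child_index(edges_parent, num_nodes=5)
index_range, edge_index = build_parent_index(edges_child, edges_left, num_nodes=5)

print(child_index[4])  # (2, 5): edges 2, 3 and 4 have node 4 as parent
start, stop = index_range[0]
print(edge_index[start:stop])  # [0, 4]: edges where node 0 is the child
```

The child lookup needs no extra ID array because the edge table is already grouped by parent, while the parent lookup goes through `edge_index` because edges with the same child are scattered across the table.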
@@ -292,7 +291,6 @@ def descendant_span(numba_ts, u):
     descends from the specified node u.
     """
     child_index = numba_ts.child_index()
-    child_range = child_index.child_range
     edges_left = numba_ts.edges_left
     edges_right = numba_ts.edges_right
     edges_child = numba_ts.edges_child
@@ -308,7 +306,7 @@ def descendant_span(numba_ts, u):
         node, left, right = stack.pop()

         # Find all child edges for this node
-        for e in range(child_range[node, 0], child_range[node, 1]):
+        for e in range(child_index[node, 0], child_index[node, 1]):
             e_left = edges_left[e]
             e_right = edges_right[e]

@@ -377,15 +375,12 @@ print("Results match!")
 .. autoclass:: NumbaTreeSequence
     :members:

-.. autoclass:: NumbaTreeIndex
+.. autoclass:: TreeIndex
     :members:

-.. autoclass:: NumbaEdgeRange
+.. autoclass:: EdgeRange
     :members:

-.. autoclass:: NumbaChildIndex
-    :members:
-
-.. autoclass:: NumbaParentIndex
+.. autoclass:: ParentIndex
     :members:
 ```

python/tests/test_jit.py

Lines changed: 9 additions & 10 deletions
@@ -122,7 +122,7 @@ def test_child_index_correctness(ts):
     numba_ts = jit_numba.jitwrap(ts)
     child_index = numba_ts.child_index()
     for node in range(ts.num_nodes):
-        start, stop = child_index.child_range[node]
+        start, stop = child_index[node]

         expected_children = []
         for edge_id in range(ts.num_edges):
@@ -144,7 +144,7 @@ def test_parent_index_correctness(ts):
     numba_ts = jit_numba.jitwrap(ts)
     parent_index = numba_ts.parent_index()
     for node in range(ts.num_nodes):
-        start, stop = parent_index.parent_range[node]
+        start, stop = parent_index.index_range[node]

         expected_parents = []
         for edge_id in range(ts.num_edges):
@@ -157,7 +157,7 @@ def test_parent_index_correctness(ts):
             assert stop > start
             actual_parent_edge_ids = []
             for j in range(start, stop):
-                edge_id = parent_index.parent_index[j]
+                edge_id = parent_index.edge_index[j]
                 actual_parent_edge_ids.append(edge_id)
                 assert ts.edges_child[edge_id] == node
             assert set(actual_parent_edge_ids) == set(expected_parents)
@@ -173,10 +173,10 @@ def test_parent_index_tree_reconstruction(ts):
         position = tree.interval.left + 0.5 * tree.span
         reconstructed_parent = np.full(ts.num_nodes, -1, dtype=np.int32)
         for node in range(ts.num_nodes):
-            start, stop = parent_index.parent_range[node]
+            start, stop = parent_index.index_range[node]
             if start != -1:
                 for j in range(start, stop):
-                    edge_id = parent_index.parent_index[j]
+                    edge_id = parent_index.edge_index[j]
                     if ts.edges_left[edge_id] <= position < ts.edges_right[edge_id]:
                         reconstructed_parent[node] = ts.edges_parent[edge_id]
                         break
@@ -204,12 +204,12 @@ def _count_children_parents_numba(numba_ts):

     for node in range(numba_ts.num_nodes):
         # Count child edges
-        child_start, child_stop = child_index.child_range[node]
+        child_start, child_stop = child_index[node]
         if child_start != -1:
             total_child_edges += child_stop - child_start

         # Count parent edges
-        parent_start, parent_stop = parent_index.parent_range[node]
+        parent_start, parent_stop = parent_index.index_range[node]
         if parent_start != -1:
             total_parent_edges += parent_stop - parent_start

@@ -435,7 +435,7 @@ def test_jitwrap_properties(ts):
 def test_numba_edge_range():

     order = np.array([1, 3, 2, 0], dtype=np.int32)
-    edge_range = jit_numba.NumbaEdgeRange(start=1, stop=3, order=order)
+    edge_range = jit_numba.EdgeRange(start=1, stop=3, order=order)

     assert edge_range.start == 1
     assert edge_range.stop == 3
@@ -520,7 +520,6 @@ def test_jit_descendant_span(ts):
     @numba.njit
     def descendant_span(numba_ts, u):
         child_index = numba_ts.child_index()
-        child_range = child_index.child_range
         edges_left = numba_ts.edges_left
         edges_right = numba_ts.edges_right
         edges_child = numba_ts.edges_child
@@ -536,7 +535,7 @@ def descendant_span(numba_ts, u):
             node, left, right = stack.pop()

             # Find all child edges for this node
-            for e in range(child_range[node, 0], child_range[node, 1]):
+            for e in range(child_index[node, 0], child_index[node, 1]):
                 e_left = edges_left[e]
                 e_right = edges_right[e]
