Skip to content

Commit 0b23799

Browse files
benjefferyjeromekelleher
authored andcommitted
Fix docs
1 parent 917c2fd commit 0b23799

File tree

2 files changed

+82
-74
lines changed

2 files changed

+82
-74
lines changed

docs/numba.md

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,18 @@ conda install numba
3939

4040
## Overview
4141

42-
The numba integration provides:
42+
The numba integration provides a {class}`tskit.TreeSequence` wrapper class {class}`NumbaTreeSequence`.
43+
This class can be used directly in `numba.njit` compiled functions, and provides several efficient
44+
methods for tree traversal:
4345

44-
- **{class}`NumbaTreeSequence`**: A Numba-compatible representation of tree sequence data
45-
- **{class}`TreeIndex`**: A class for efficient tree iteration
46-
- **{class}`EdgeRange`**: Container class for edge ranges during iteration
47-
- **{class}`ParentIndex`**: Container for parent edge index information
46+
- **{meth}`~NumbaTreeSequence.tree_index`**: For efficient iteration through the trees in the sequence
47+
- **{meth}`~NumbaTreeSequence.parent_index`**: For efficient access to parent edge information, to
48+
traverse upwards through the ARG.
49+
- **{meth}`~NumbaTreeSequence.child_index`**: For efficient access to child edge information, to
50+
traverse downwards through the ARG.
4851

49-
These classes are designed to work within Numba's `@njit` decorated functions,
50-
allowing you to write high-performance tree sequence analysis code.
52+
These methods are optimised to work within Numba's `@njit` decorated functions,
53+
allowing you to write high-performance tree sequence analysis code in a plain Python style.
5154

5255
## Basic Usage
5356

@@ -255,11 +258,11 @@ print("Time taken:", time.time() - t)
255258

256259
## ARG Traversal
257260

258-
Beyond iterating through trees, you may need to traverse the ARG vertically. The {meth}`~NumbaTreeSequence.child_index` method and {class}`ParentIndex` class provide efficient access to parent-child relationships in the edge table within `numba.njit` functions.
261+
Beyond iterating through trees, you may need to traverse the ARG vertically. The {meth}`~NumbaTreeSequence.child_index` and {meth}`~NumbaTreeSequence.parent_index` methods provide efficient access to parent-child relationships in the edge table within `numba.njit` functions.
259262

260-
The {meth}`~NumbaTreeSequence.child_index` method returns an array that allows you to efficiently find all edges where a given node is the parent. Since edges are already sorted by parent in the tskit data model, this is implemented using simple range indexing. For any node `u`, the returned array `child_index[u]` gives a tuple of the start and stop indices in the tskit edge table where node `u` is the parent.
263+
The {meth}`~NumbaTreeSequence.child_index` method returns an array that allows you to efficiently find all edges where a given node is the parent. Since edges are already sorted by parent in the tskit data model, this is implemented using simple range indexing. For any node `u`, the returned array `child_index[u]` gives a tuple of the start and stop indices in the tskit edge table where node `u` is the parent. The index is calculated on each call to `child_index()` so should be called once.
261264

262-
The {meth}`~NumbaTreeSequence.parent_index` method creates a {class}`ParentIndex` that allows you to efficiently find all edges where a given node is the child. Since edges are not sorted by child in the edge table, the returned class contains a custom index that sorts edge IDs by child node (and then by left coordinate). For any node `u`, `parent_index.index_range[u]` gives a tuple of the start and stop indices in the `edge_index` array, and `parent_index.edge_index[start:stop]` gives the actual tskit edge IDs.
265+
The {meth}`~NumbaTreeSequence.parent_index` method creates a {class}`ParentIndex` that allows you to efficiently find all edges where a given node is the child. Since edges are not sorted by child in the edge table, the returned class contains a custom index that sorts edge IDs by child node (and then by left coordinate). For any node `u`, `parent_index.index_range[u]` gives a tuple of the start and stop indices in the `parent_index.edge_index` array, and `parent_index.edge_index[start:stop]` gives the actual tskit edge IDs.
263266

264267
Both can be obtained from a {class}`NumbaTreeSequence`:
265268

@@ -268,20 +271,22 @@ Both can be obtained from a {class}`NumbaTreeSequence`:
268271
child_index = numba_ts.child_index()
269272
parent_index = numba_ts.parent_index()
270273
271-
# Example: find all edges where node 5 is the parent
274+
# Example: find all left coordinates of edges where node 5 is the parent
272275
start, stop = child_index[5]
273-
print(f"Node 5 has {stop - start} child edges")
276+
left_coords = numba_ts.edges_left[start:stop]
277+
print(left_coords)
274278
275-
# Example: find all edges where node 3 is the child
279+
# Example: find all right coordinates of edges where node 3 is the child
276280
start, stop = parent_index.index_range[3]
277-
print(f"Node 3 appears as child in {stop - start} edges")
281+
right_coords = numba_ts.edges_right[start:stop]
282+
print(right_coords)
278283
```
279284

280285
These indexes enable efficient algorithms that need to traverse parent-child relationships in the ARG, such as computing descendant sets, ancestral paths, or subtree properties.
281286

282287
### Example - descendant span calculation
283288

284-
Here's an example of using the ARG traversal classes to calculate the total sequence length over which each node descends from a specified node:
289+
Here's an example of using the ARG traversal indexes to calculate the total sequence length over which each node descends from a specified node:
285290

286291
```{code-cell} python
287292
@numba.njit

python/tskit/jit/numba.py

Lines changed: 62 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ class EdgeRange:
4646
4747
Attributes
4848
----------
49-
start : int32
49+
start : int
5050
Starting index of the edge range (inclusive).
51-
stop : int32
51+
stop : int
5252
Stopping index of the edge range (exclusive).
53-
order : int32[]
54-
Array containing edge IDs in the order they should be processed.
53+
order : numpy.ndarray
54+
Array (dtype=np.int32) containing edge IDs in the order they should be processed.
5555
The edge ids in this range are order[start:stop].
5656
"""
5757

@@ -73,10 +73,11 @@ class ParentIndex:
7373
7474
Attributes
7575
----------
76-
edge_index : int32[num_edges]
77-
Array of edge IDs sorted by child node and left coordinate.
78-
index_range : int32[num_nodes, 2]
79-
For each node, the [start, stop) range in edge_index where this node is child.
76+
edge_index : numpy.ndarray
77+
Array (dtype=np.int32) of edge IDs sorted by child node and left coordinate.
78+
index_range : numpy.ndarray
79+
Array (dtype=np.int32, shape=(num_nodes, 2)) where each row contains the
80+
[start, stop) range in edge_index where this node is the child.
8081
"""
8182

8283
def __init__(self, edge_index, index_range):
@@ -100,20 +101,20 @@ class TreeIndex:
100101
----------
101102
ts : NumbaTreeSequence
102103
Reference to the tree sequence being traversed.
103-
index : int32
104+
index : int
104105
Current tree index. -1 indicates no current tree (null state).
105-
direction : int32
106+
direction : int
106107
Traversal direction: tskit.FORWARD or tskit.REVERSE. tskit.NULL
107108
if uninitialised.
108-
interval : tuple of float64
109+
interval : tuple
109110
Genomic interval (left, right) covered by the current tree.
110-
in_range : NumbaEdgeRange
111+
in_range : EdgeRange
111112
Edges being added to form this current tree, relative to the last state
112-
out_range : NumbaEdgeRange
113+
out_range : EdgeRange
113114
Edges being removed to form this current tree, relative to the last state
114-
site_range : tuple of int32
115+
site_range : tuple
115116
Range of sites in the current tree (start, stop).
116-
mutation_range : tuple of int32
117+
mutation_range : tuple
117118
Range of mutations in the current tree (start, stop).
118119
119120
Example
@@ -319,54 +320,56 @@ class NumbaTreeSequence:
319320
320321
Attributes
321322
----------
322-
num_trees : int32
323+
num_trees : int
323324
Number of trees in the tree sequence.
324-
num_nodes : int32
325+
num_nodes : int
325326
Number of nodes in the tree sequence.
326-
num_samples : int32
327+
num_samples : int
327328
Number of samples in the tree sequence.
328-
num_edges : int32
329+
num_edges : int
329330
Number of edges in the tree sequence.
330-
num_sites : int32
331+
num_sites : int
331332
Number of sites in the tree sequence.
332-
num_mutations : int32
333+
num_mutations : int
333334
Number of mutations in the tree sequence.
334-
sequence_length : float64
335+
sequence_length : float
335336
Total sequence length of the tree sequence.
336-
edges_left : float64[]
337-
Left coordinates of edges.
338-
edges_right : float64[]
339-
Right coordinates of edges.
340-
edges_parent : int32[]
341-
Parent node IDs for each edge.
342-
edges_child : int32[]
343-
Child node IDs for each edge.
344-
nodes_time : float64[]
345-
Time values for each node.
346-
nodes_flags : uint32[]
347-
Flag values for each node.
348-
nodes_population : int32[]
349-
Population IDs for each node.
350-
nodes_individual : int32[]
351-
Individual IDs for each node.
352-
individuals_flags : uint32[]
353-
Flag values for each individual.
354-
sites_position : float64[]
355-
Positions of sites along the sequence.
356-
mutations_site : int32[]
357-
Site IDs for each mutation.
358-
mutations_node : int32[]
359-
Node IDs for each mutation.
360-
mutations_parent : int32[]
361-
Parent mutation IDs.
362-
mutations_time : float64[]
363-
Time values for each mutation.
364-
breakpoints : float64[]
365-
Genomic positions where trees change.
366-
indexes_edge_insertion_order : int32[]
367-
Order in which edges are inserted during tree building.
368-
indexes_edge_removal_order : int32[]
369-
Order in which edges are removed during tree building.
337+
edges_left : numpy.ndarray
338+
Array (dtype=np.float64) of left coordinates of edges.
339+
edges_right : numpy.ndarray
340+
Array (dtype=np.float64) of right coordinates of edges.
341+
edges_parent : numpy.ndarray
342+
Array (dtype=np.int32) of parent node IDs for each edge.
343+
edges_child : numpy.ndarray
344+
Array (dtype=np.int32) of child node IDs for each edge.
345+
nodes_time : numpy.ndarray
346+
Array (dtype=np.float64) of time values for each node.
347+
nodes_flags : numpy.ndarray
348+
Array (dtype=np.uint32) of flag values for each node.
349+
nodes_population : numpy.ndarray
350+
Array (dtype=np.int32) of population IDs for each node.
351+
nodes_individual : numpy.ndarray
352+
Array (dtype=np.int32) of individual IDs for each node.
353+
individuals_flags : numpy.ndarray
354+
Array (dtype=np.uint32) of flag values for each individual.
355+
sites_position : numpy.ndarray
356+
Array (dtype=np.float64) of positions of sites along the sequence.
357+
mutations_site : numpy.ndarray
358+
Array (dtype=np.int32) of site IDs for each mutation.
359+
mutations_node : numpy.ndarray
360+
Array (dtype=np.int32) of node IDs for each mutation.
361+
mutations_parent : numpy.ndarray
362+
Array (dtype=np.int32) of parent mutation IDs.
363+
mutations_time : numpy.ndarray
364+
Array (dtype=np.float64) of time values for each mutation.
365+
breakpoints : numpy.ndarray
366+
Array (dtype=np.float64) of genomic positions where trees change.
367+
indexes_edge_insertion_order : numpy.ndarray
368+
Array (dtype=np.int32) specifying the order in which edges are inserted
369+
during tree building.
370+
indexes_edge_removal_order : numpy.ndarray
371+
Array (dtype=np.int32) specifying the order in which edges are removed
372+
during tree building.
370373
371374
"""
372375

@@ -446,9 +449,9 @@ def child_index(self):
446449
"""
447450
Create child index array for finding child edges of nodes.
448451
449-
:return: Array where each row [node] contains [start, stop) range of edges
450-
where this node is the parent.
451-
:rtype: int32[num_nodes, 2]
452+
:return: A numpy array (dtype=np.int32, shape=(num_nodes, 2)) where each row
453+
contains the [start, stop) range of edges where this node is the parent.
454+
:rtype: numpy.ndarray
452455
"""
453456
child_range = np.full((self.num_nodes, 2), -1, dtype=np.int32)
454457
edges_parent = self.edges_parent

0 commit comments

Comments
 (0)