|
| 1 | +--- |
| 2 | +jupytext: |
| 3 | + text_representation: |
| 4 | + extension: .md |
| 5 | + format_name: myst |
| 6 | + format_version: 0.12 |
| 7 | + jupytext_version: 1.9.1 |
| 8 | +kernelspec: |
| 9 | + display_name: Python 3 |
| 10 | + language: python |
| 11 | + name: python3 |
| 12 | +--- |
| 13 | + |
| 14 | +```{currentmodule} tskit.jit.numba |
| 15 | +``` |
| 16 | + |
| 17 | +(sec_numba)= |
| 18 | + |
| 19 | +# Numba Integration |
| 20 | + |
| 21 | +The `tskit.jit.numba` module provides classes for working with tree sequences |
| 22 | +from [Numba](https://numba.pydata.org/) jit-compiled Python code. Such code can run |
| 23 | +up to hundreds of times faster than normal Python, yet avoids the difficulties of writing |
| 24 | +C or other low-level code. |
| 25 | + |
| 26 | +:::{note} |
| 27 | +Numba is not a direct dependency of tskit, so will not be available unless installed: |
| 28 | + |
| 29 | +```bash |
| 30 | +pip install numba |
| 31 | +``` |
| 32 | + |
| 33 | +or |
| 34 | + |
| 35 | +```bash |
| 36 | +conda install numba |
| 37 | +``` |
| 38 | +::: |
| 39 | + |
| 40 | +## Overview |
| 41 | + |
| 42 | +The numba integration provides: |
| 43 | + |
| 44 | +- **{class}`NumbaTreeSequence`**: A Numba-compatible representation of tree sequence data |
| 45 | +- **{class}`NumbaTreeIndex`**: A class for efficient tree iteration |
| 46 | +- **{class}`NumbaEdgeRange`**: Container class for edge ranges during iteration |
| 47 | + |
| 48 | +These classes are designed to work within Numba's `@njit` decorated functions, |
| 49 | +allowing you to write high-performance tree sequence analysis code. |
| 50 | + |
| 51 | +## Basic Usage |
| 52 | + |
| 53 | +The ``tskit.jit.numba`` module is not imported with normal `tskit` so must be imported explicitly: |
| 54 | +```{code-cell} python |
| 55 | +import numpy as np |
| 56 | +import tskit |
| 57 | +import tskit.jit.numba as tskit_numba |
| 58 | +``` |
| 59 | + |
| 60 | +Normal third-party classes such as {class}`tskit.TreeSequence` can't be used in `numba.njit` compiled |
| 61 | +functions so the {class}`tskit.TreeSequence` must be wrapped in a {class}`NumbaTreeSequence` by |
| 62 | +{meth}`jitwrap`. This must be done outside `njit` code: |
| 63 | + |
| 64 | +```{code-cell} python |
| 65 | +import msprime |
| 66 | +
|
| 67 | +ts = msprime.sim_ancestry( |
| 68 | + samples=50000, |
| 69 | + sequence_length=100000, |
| 70 | + recombination_rate=0.1, |
| 71 | + random_seed=42 |
| 72 | +) |
| 73 | +numba_ts = tskit_numba.jitwrap(ts) |
| 74 | +print(type(numba_ts)) |
| 75 | +``` |
| 76 | + |
| 77 | +## Tree Iteration |
| 78 | + |
| 79 | +Tree iteration can be performed using the {class}`NumbaTreeIndex` class. |
| 80 | +This class provides `next()` and `prev()` methods for forward and backward iteration through the trees in a tree sequence. Its `in_range` and `out_range` attributes provide the edges that must be added or removed to form the current |
| 81 | +tree from the previous tree, along with the current tree `interval` and its sites and mutations through `site_range` and `mutation_range`. |
| 82 | + |
| 83 | +A `NumbaTreeIndex` instance can be obtained from a `NumbaTreeSequence` using the `tree_index()` method. The initial state of this is of a "null" tree outside the range of the tree sequence, the first call to `next()` or `prev()` will be to the first, or last tree sequence tree respectively. After that, the `in_range` and `out_range` attributes will provide the edges that must be added or removed to form the current tree from the previous tree. For example |
| 84 | +`tree_index.in_range.order[in_range.start:in_range.stop]` will give the edge ids that are new in the current tree, and `tree_index.out_range.order[out_range.start:out_range.stop]` will give the edge ids that are no longer present in the current tree. `tree_index.site_range` and |
| 85 | +`tree_index.mutation_range` give the indexes into the tree sequences site and mutation arrays. |
| 86 | + |
| 87 | +As a simple example we can calculate the number of edges in each tree in a tree sequence: |
| 88 | + |
| 89 | +```{code-cell} python |
| 90 | +import numba |
| 91 | +
|
| 92 | +@numba.njit |
| 93 | +def edges_per_tree(numba_ts): |
| 94 | + tree_index = numba_ts.tree_index() |
| 95 | + current_num_edges = 0 |
| 96 | + num_edges = [] |
| 97 | + |
| 98 | + # Move forward through the trees |
| 99 | + while tree_index.next(): |
| 100 | + # Access current tree information |
| 101 | + in_range = tree_index.in_range |
| 102 | + out_range = tree_index.out_range |
| 103 | + |
| 104 | + current_num_edges -= (out_range.stop - out_range.start) |
| 105 | + current_num_edges += (in_range.stop - in_range.start) |
| 106 | + num_edges.append(current_num_edges) |
| 107 | + return num_edges |
| 108 | +``` |
| 109 | + |
| 110 | +```{code-cell} python |
| 111 | +:tags: [hide-cell] |
| 112 | +# Warm up the JIT compiler |
| 113 | +edges = edges_per_tree(numba_ts) |
| 114 | +``` |
| 115 | + |
| 116 | + |
| 117 | +```{code-cell} python |
| 118 | +import time |
| 119 | +
|
| 120 | +t = time.time() |
| 121 | +jit_num_edges = edges_per_tree(numba_ts) |
| 122 | +print(f"JIT Time taken: {time.time() - t:.4f} seconds") |
| 123 | +``` |
| 124 | + |
| 125 | +Doing the same thing with the normal `tskit` API would be much slower: |
| 126 | + |
| 127 | +```{code-cell} python |
| 128 | +t = time.time() |
| 129 | +python_num_edges = [] |
| 130 | +for tree in ts.trees(): |
| 131 | + python_num_edges.append(tree.num_edges) |
| 132 | +print(f"Normal Time taken: {time.time() - t:.4f} seconds") |
| 133 | +
|
| 134 | +assert jit_num_edges == python_num_edges, "JIT and normal results do not match!" |
| 135 | +``` |
| 136 | + |
| 137 | +## Example - diversity calculation |
| 138 | + |
| 139 | +As a more interesting example we can calculate genetic diversity (also known as pi). |
| 140 | +For this example we'll be calculating based on the distance in the tree between samples. |
| 141 | +(`mode="branch"` in the tskit API.) |
| 142 | + |
| 143 | +This example also shows the style of Python code that gives best performance under `numba` |
| 144 | +JIT compilation - using simple loops and fixed-size arrays with minimal object attribute access. |
| 145 | + |
| 146 | +```{code-cell} python |
| 147 | + @numba.njit |
| 148 | + def diversity(numba_ts): |
| 149 | + # Cache arrays to avoid repeated attribute access in |
| 150 | + # tight loops |
| 151 | + edge_child = numba_ts.edges_child |
| 152 | + edge_parent = numba_ts.edges_parent |
| 153 | + node_times = numba_ts.nodes_time |
| 154 | + node_flags = numba_ts.nodes_flags |
| 155 | + |
| 156 | + if numba_ts.num_samples <= 1: |
| 157 | + return 0.0 |
| 158 | +
|
| 159 | + parent = np.full(numba_ts.num_nodes, -1, dtype=np.int32) |
| 160 | + branch_length = np.zeros(numba_ts.num_nodes, dtype=np.float64) |
| 161 | + state = np.zeros(numba_ts.num_nodes, dtype=np.int32) |
| 162 | + summary = np.zeros(numba_ts.num_nodes, dtype=np.float64) |
| 163 | +
|
| 164 | + n = float(numba_ts.num_samples) |
| 165 | + two_over_denom = 2.0 / (n * (n - 1.0)) |
| 166 | + sample_summary = 2.0 / n |
| 167 | +
|
| 168 | + # Retrieve this constant outside the loop |
| 169 | + # to avoid repeated attribute access |
| 170 | + NODE_IS_SAMPLE = tskit.NODE_IS_SAMPLE |
| 171 | + # Find the sample nodes and initialize their states |
| 172 | + for node in range(numba_ts.num_nodes): |
| 173 | + if node_flags[node] & NODE_IS_SAMPLE: |
| 174 | + state[node] = 1.0 |
| 175 | + summary[node] = sample_summary |
| 176 | +
|
| 177 | + result = 0.0 |
| 178 | + running_sum = 0.0 |
| 179 | + tree_index = numba_ts.tree_index() |
| 180 | +
|
| 181 | + # Now iterate through the trees |
| 182 | + while tree_index.next(): |
| 183 | + # Process the outgoing edges |
| 184 | + for j in range(tree_index.out_range.start, tree_index.out_range.stop): |
| 185 | + h = tree_index.out_range.order[j] |
| 186 | + u = edge_child[h] |
| 187 | +
|
| 188 | + running_sum -= branch_length[u] * summary[u] |
| 189 | + parent[u] = -1 |
| 190 | + branch_length[u] = 0.0 |
| 191 | +
|
| 192 | + u = edge_parent[h] |
| 193 | + while u != -1: |
| 194 | + running_sum -= branch_length[u] * summary[u] |
| 195 | + state[u] -= state[edge_child[h]] |
| 196 | + summary[u] = state[u] * (n - state[u]) * two_over_denom |
| 197 | + running_sum += branch_length[u] * summary[u] |
| 198 | + u = parent[u] |
| 199 | +
|
| 200 | + # Process the incoming edges |
| 201 | + for j in range(tree_index.in_range.start, tree_index.in_range.stop): |
| 202 | + h = tree_index.in_range.order[j] |
| 203 | + u = edge_child[h] |
| 204 | + v = edge_parent[h] |
| 205 | +
|
| 206 | + parent[u] = v |
| 207 | + branch_length[u] = node_times[v] - node_times[u] |
| 208 | + running_sum += branch_length[u] * summary[u] |
| 209 | +
|
| 210 | + u = v |
| 211 | + while u != -1: |
| 212 | + running_sum -= branch_length[u] * summary[u] |
| 213 | + state[u] += state[edge_child[h]] |
| 214 | + summary[u] = state[u] * (n - state[u]) * two_over_denom |
| 215 | + running_sum += branch_length[u] * summary[u] |
| 216 | + u = parent[u] |
| 217 | +
|
| 218 | + result += running_sum * ( |
| 219 | + tree_index.interval[1] - tree_index.interval[0] |
| 220 | + ) |
| 221 | +
|
| 222 | + return result / numba_ts.sequence_length |
| 223 | +``` |
| 224 | + |
| 225 | +```{code-cell} python |
| 226 | +:tags: [hide-cell] |
| 227 | +# Warm up the JIT |
| 228 | +d = diversity(numba_ts) |
| 229 | +``` |
| 230 | + |
| 231 | +```{code-cell} python |
| 232 | +t = time.time() |
| 233 | +d = diversity(numba_ts) |
| 234 | +print("Diversity:", d) |
| 235 | +print("Time taken:", time.time() - t) |
| 236 | +``` |
| 237 | + |
| 238 | +As this code is written for this specific diversity calculation it is even faster |
| 239 | +than the tskit C implementation, called here from Python: |
| 240 | + |
| 241 | +```{code-cell} python |
| 242 | +t = time.time() |
| 243 | +d_tskit = ts.diversity(mode="branch") |
| 244 | +print("Diversity (tskit):", d_tskit) |
| 245 | +print("Time taken:", time.time() - t) |
| 246 | +``` |
| 247 | + |
| 248 | + |
| 249 | + |
| 250 | + |
| 251 | +## API Reference |
| 252 | + |
| 253 | +```{eval-rst} |
| 254 | +.. currentmodule:: tskit.jit.numba |
| 255 | +
|
| 256 | +.. autofunction:: jitwrap |
| 257 | +
|
| 258 | +.. autoclass:: NumbaTreeSequence |
| 259 | + :members: |
| 260 | +
|
| 261 | +.. autoclass:: NumbaTreeIndex |
| 262 | + :members: |
| 263 | +
|
| 264 | +.. autoclass:: NumbaEdgeRange |
| 265 | + :members: |
| 266 | +``` |
0 commit comments