Skip to content

Commit d1eb1b4

Browse files
committed
Add initial numba support
1 parent 0ea8e64 commit d1eb1b4

File tree

9 files changed

+1248
-1
lines changed

9 files changed

+1248
-1
lines changed

.github/workflows/tests.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,22 @@ jobs:
115115
conda activate anaconda-client-env
116116
python -c "import tskit;tskit.Tree.generate_star(5).tree_sequence.draw_svg(path='test.svg')"
117117
118+
- name: Run JIT code coverage
119+
working-directory: python
120+
run: |
121+
source ~/.profile
122+
conda activate anaconda-client-env
123+
NUMBA_DISABLE_JIT=1 python -m pytest -x --cov=tskit --cov-report=xml --cov-branch -n2 tests/test_jit.py
124+
125+
- name: Upload coverage to Codecov
126+
uses: codecov/[email protected]
127+
with:
128+
token: ${{ secrets.CODECOV_TOKEN }}
129+
working-directory: python
130+
fail_ci_if_error: false
131+
flags: python-tests-no-jit
132+
name: codecov-umbrella
133+
verbose: true
118134
- name: Run tests
119135
working-directory: python
120136
run: |
@@ -136,6 +152,7 @@ jobs:
136152
name: codecov-umbrella
137153
verbose: true
138154

155+
139156
test-numpy1:
140157
name: Numpy 1.x
141158
runs-on: ubuntu-24.04

docs/_toc.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ parts:
2020
- caption: Interfaces
2121
chapters:
2222
- file: python-api
23+
- file: numba
2324
- file: c-api
2425
- file: cli
2526
- file: file-formats

docs/numba.md

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
---
2+
jupytext:
3+
text_representation:
4+
extension: .md
5+
format_name: myst
6+
format_version: 0.12
7+
jupytext_version: 1.9.1
8+
kernelspec:
9+
display_name: Python 3
10+
language: python
11+
name: python3
12+
---
13+
14+
```{currentmodule} tskit.jit.numba
15+
```
16+
17+
(sec_numba)=
18+
19+
# Numba Integration
20+
21+
The `tskit.jit.numba` module provides classes for working with tree sequences
22+
from [Numba](https://numba.pydata.org/) jit-compiled Python code. Such code can run
23+
up to hundreds of times faster than normal Python, yet avoids the difficulties of writing
24+
C or other low-level code.
25+
26+
:::{note}
27+
Numba is not a direct dependency of tskit, so will not be available unless installed:
28+
29+
```bash
30+
pip install numba
31+
```
32+
33+
or
34+
35+
```bash
36+
conda install numba
37+
```
38+
:::
39+
40+
## Overview
41+
42+
The numba integration provides:
43+
44+
- **{class}`NumbaTreeSequence`**: A Numba-compatible representation of tree sequence data
45+
- **{class}`NumbaTreeIndex`**: A class for efficient tree iteration
46+
- **{class}`NumbaEdgeRange`**: Container class for edge ranges during iteration
47+
48+
These classes are designed to work within Numba's `@njit` decorated functions,
49+
allowing you to write high-performance tree sequence analysis code.
50+
51+
## Basic Usage
52+
53+
The ``tskit.jit.numba`` module is not imported with normal `tskit` so must be imported explicitly:
54+
```{code-cell} python
55+
import numpy as np
56+
import tskit
57+
import tskit.jit.numba as tskit_numba
58+
```
59+
60+
Normal third-party classes such as {class}`tskit.TreeSequence` can't be used in `numba.njit` compiled
61+
functions so the {class}`tskit.TreeSequence` must be wrapped in a {class}`NumbaTreeSequence` by
62+
{meth}`jitwrap`. This must be done outside `njit` code:
63+
64+
```{code-cell} python
65+
import msprime
66+
67+
ts = msprime.sim_ancestry(
68+
samples=50000,
69+
sequence_length=100000,
70+
recombination_rate=0.1,
71+
random_seed=42
72+
)
73+
numba_ts = tskit_numba.jitwrap(ts)
74+
print(type(numba_ts))
75+
```
76+
77+
## Tree Iteration
78+
79+
Tree iteration can be performed using the {class}`NumbaTreeIndex` class.
80+
This class provides `next()` and `prev()` methods for forward and backward iteration through the trees in a tree sequence. Its `in_range` and `out_range` attributes provide the edges that must be added or removed to form the current
81+
tree from the previous tree, along with the current tree `interval` and its sites and mutations through `site_range` and `mutation_range`.
82+
83+
A `NumbaTreeIndex` instance can be obtained from a `NumbaTreeSequence` using the `tree_index()` method. The initial state of this is of a "null" tree outside the range of the tree sequence, the first call to `next()` or `prev()` will be to the first, or last tree sequence tree respectively. After that, the `in_range` and `out_range` attributes will provide the edges that must be added or removed to form the current tree from the previous tree. For example
84+
`tree_index.in_range.order[in_range.start:in_range.stop]` will give the edge ids that are new in the current tree, and `tree_index.out_range.order[out_range.start:out_range.stop]` will give the edge ids that are no longer present in the current tree. `tree_index.site_range` and
85+
`tree_index.mutation_range` give the indexes into the tree sequences site and mutation arrays.
86+
87+
As a simple example we can calculate the number of edges in each tree in a tree sequence:
88+
89+
```{code-cell} python
90+
import numba
91+
92+
@numba.njit
93+
def edges_per_tree(numba_ts):
94+
tree_index = numba_ts.tree_index()
95+
current_num_edges = 0
96+
num_edges = []
97+
98+
# Move forward through the trees
99+
while tree_index.next():
100+
# Access current tree information
101+
in_range = tree_index.in_range
102+
out_range = tree_index.out_range
103+
104+
current_num_edges -= (out_range.stop - out_range.start)
105+
current_num_edges += (in_range.stop - in_range.start)
106+
num_edges.append(current_num_edges)
107+
return num_edges
108+
```
109+
110+
```{code-cell} python
111+
:tags: [hide-cell]
112+
# Warm up the JIT compiler
113+
edges = edges_per_tree(numba_ts)
114+
```
115+
116+
117+
```{code-cell} python
118+
import time
119+
120+
t = time.time()
121+
jit_num_edges = edges_per_tree(numba_ts)
122+
print(f"JIT Time taken: {time.time() - t:.4f} seconds")
123+
```
124+
125+
Doing the same thing with the normal `tskit` API would be much slower:
126+
127+
```{code-cell} python
128+
t = time.time()
129+
python_num_edges = []
130+
for tree in ts.trees():
131+
python_num_edges.append(tree.num_edges)
132+
print(f"Normal Time taken: {time.time() - t:.4f} seconds")
133+
134+
assert jit_num_edges == python_num_edges, "JIT and normal results do not match!"
135+
```
136+
137+
## Example - diversity calculation
138+
139+
As a more interesting example we can calculate genetic diversity (also known as pi).
140+
For this example we'll be calculating based on the distance in the tree between samples.
141+
(`mode="branch"` in the tskit API.)
142+
143+
This example also shows the style of Python code that gives best performance under `numba`
144+
JIT compilation - using simple loops and fixed-size arrays with minimal object attribute access.
145+
146+
```{code-cell} python
147+
@numba.njit
148+
def diversity(numba_ts):
149+
# Cache arrays to avoid repeated attribute access in
150+
# tight loops
151+
edge_child = numba_ts.edges_child
152+
edge_parent = numba_ts.edges_parent
153+
node_times = numba_ts.nodes_time
154+
node_flags = numba_ts.nodes_flags
155+
156+
if numba_ts.num_samples <= 1:
157+
return 0.0
158+
159+
parent = np.full(numba_ts.num_nodes, -1, dtype=np.int32)
160+
branch_length = np.zeros(numba_ts.num_nodes, dtype=np.float64)
161+
state = np.zeros(numba_ts.num_nodes, dtype=np.int32)
162+
summary = np.zeros(numba_ts.num_nodes, dtype=np.float64)
163+
164+
n = float(numba_ts.num_samples)
165+
two_over_denom = 2.0 / (n * (n - 1.0))
166+
sample_summary = 2.0 / n
167+
168+
# Retrieve this constant outside the loop
169+
# to avoid repeated attribute access
170+
NODE_IS_SAMPLE = tskit.NODE_IS_SAMPLE
171+
# Find the sample nodes and initialize their states
172+
for node in range(numba_ts.num_nodes):
173+
if node_flags[node] & NODE_IS_SAMPLE:
174+
state[node] = 1.0
175+
summary[node] = sample_summary
176+
177+
result = 0.0
178+
running_sum = 0.0
179+
tree_index = numba_ts.tree_index()
180+
181+
# Now iterate through the trees
182+
while tree_index.next():
183+
# Process the outgoing edges
184+
for j in range(tree_index.out_range.start, tree_index.out_range.stop):
185+
h = tree_index.out_range.order[j]
186+
u = edge_child[h]
187+
188+
running_sum -= branch_length[u] * summary[u]
189+
parent[u] = -1
190+
branch_length[u] = 0.0
191+
192+
u = edge_parent[h]
193+
while u != -1:
194+
running_sum -= branch_length[u] * summary[u]
195+
state[u] -= state[edge_child[h]]
196+
summary[u] = state[u] * (n - state[u]) * two_over_denom
197+
running_sum += branch_length[u] * summary[u]
198+
u = parent[u]
199+
200+
# Process the incoming edges
201+
for j in range(tree_index.in_range.start, tree_index.in_range.stop):
202+
h = tree_index.in_range.order[j]
203+
u = edge_child[h]
204+
v = edge_parent[h]
205+
206+
parent[u] = v
207+
branch_length[u] = node_times[v] - node_times[u]
208+
running_sum += branch_length[u] * summary[u]
209+
210+
u = v
211+
while u != -1:
212+
running_sum -= branch_length[u] * summary[u]
213+
state[u] += state[edge_child[h]]
214+
summary[u] = state[u] * (n - state[u]) * two_over_denom
215+
running_sum += branch_length[u] * summary[u]
216+
u = parent[u]
217+
218+
result += running_sum * (
219+
tree_index.interval[1] - tree_index.interval[0]
220+
)
221+
222+
return result / numba_ts.sequence_length
223+
```
224+
225+
```{code-cell} python
226+
:tags: [hide-cell]
227+
# Warm up the JIT
228+
d = diversity(numba_ts)
229+
```
230+
231+
```{code-cell} python
232+
t = time.time()
233+
d = diversity(numba_ts)
234+
print("Diversity:", d)
235+
print("Time taken:", time.time() - t)
236+
```
237+
238+
As this code is written for this specific diversity calculation it is even faster
239+
than the tskit C implementation, called here from Python:
240+
241+
```{code-cell} python
242+
t = time.time()
243+
d_tskit = ts.diversity(mode="branch")
244+
print("Diversity (tskit):", d_tskit)
245+
print("Time taken:", time.time() - t)
246+
```
247+
248+
249+
250+
251+
## API Reference
252+
253+
```{eval-rst}
254+
.. currentmodule:: tskit.jit.numba
255+
256+
.. autofunction:: jitwrap
257+
258+
.. autoclass:: NumbaTreeSequence
259+
:members:
260+
261+
.. autoclass:: NumbaTreeIndex
262+
:members:
263+
264+
.. autoclass:: NumbaEdgeRange
265+
:members:
266+
```

python/requirements/CI-complete/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ lshmm==0.0.8
66
msgpack==1.1.0
77
msprime==1.3.3
88
networkx==3.2.1
9+
numba==0.61.2
910
portion==2.6.0
1011
pytest==8.3.5
1112
pytest-cov==6.0.0

python/requirements/CI-docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ sphinx-autodoc-typehints==2.3.0
44
sphinx-issues==5.0.0
55
sphinx-argparse==0.5.2
66
msprime==1.3.3
7+
numba==0.61.2
78
sphinx-book-theme
89
pandas==2.2.3

python/requirements/CI-tests-pip/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ networkx==3.2.1
1111
msgpack==1.1.0
1212
newick==1.10.0
1313
kastore==0.3.3
14-
jsonschema==4.23.0
14+
jsonschema==4.23.0
15+
numba>=0.60.0

0 commit comments

Comments
 (0)