Skip to content

Commit 2a4ec4f

Browse files
committed
Perf tweaks to diversity (~20% drop in time)
1 parent d1eb1b4 commit 2a4ec4f

File tree

3 files changed

+48
-31
lines changed

3 files changed

+48
-31
lines changed

docs/numba.md

Lines changed: 22 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -183,37 +183,44 @@ JIT compilation - using simple loops and fixed-size arrays with minimal object a
183183
# Process the outgoing edges
184184
for j in range(tree_index.out_range.start, tree_index.out_range.stop):
185185
h = tree_index.out_range.order[j]
186-
u = edge_child[h]
186+
child = edge_child[h]
187+
child_parent = edge_parent[h]
187188
188-
running_sum -= branch_length[u] * summary[u]
189-
parent[u] = -1
190-
branch_length[u] = 0.0
189+
running_sum -= branch_length[child] * summary[child]
190+
parent[child] = -1
191+
branch_length[child] = 0.0
191192
192-
u = edge_parent[h]
193+
u = child_parent
194+
parent_u = parent[u]
193195
while u != -1:
194196
running_sum -= branch_length[u] * summary[u]
195-
state[u] -= state[edge_child[h]]
197+
state[u] -= state[child]
196198
summary[u] = state[u] * (n - state[u]) * two_over_denom
197199
running_sum += branch_length[u] * summary[u]
198-
u = parent[u]
200+
u = parent_u
201+
if u != -1:
202+
parent_u = parent[u]
199203
200204
# Process the incoming edges
201205
for j in range(tree_index.in_range.start, tree_index.in_range.stop):
202206
h = tree_index.in_range.order[j]
203-
u = edge_child[h]
204-
v = edge_parent[h]
207+
child = edge_child[h]
208+
child_parent = edge_parent[h]
205209
206-
parent[u] = v
207-
branch_length[u] = node_times[v] - node_times[u]
208-
running_sum += branch_length[u] * summary[u]
210+
parent[child] = child_parent
211+
branch_length[child] = node_times[child_parent] - node_times[child]
212+
running_sum += branch_length[child] * summary[child]
209213
210-
u = v
214+
u = child_parent
215+
parent_u = parent[u]
211216
while u != -1:
212217
running_sum -= branch_length[u] * summary[u]
213-
state[u] += state[edge_child[h]]
218+
state[u] += state[child]
214219
summary[u] = state[u] * (n - state[u]) * two_over_denom
215220
running_sum += branch_length[u] * summary[u]
216-
u = parent[u]
221+
u = parent_u
222+
if u != -1:
223+
parent_u = parent[u]
217224
218225
result += running_sum * (
219226
tree_index.interval[1] - tree_index.interval[0]

python/tests/test_jit.py

Lines changed: 22 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -207,37 +207,44 @@ def diversity(numba_ts):
207207
# Process the outgoing edges
208208
for j in range(tree_index.out_range.start, tree_index.out_range.stop):
209209
h = tree_index.out_range.order[j]
210-
u = edge_child[h]
210+
child = edge_child[h]
211+
child_parent = edge_parent[h]
211212

212-
running_sum -= branch_length[u] * summary[u]
213-
parent[u] = -1
214-
branch_length[u] = 0.0
213+
running_sum -= branch_length[child] * summary[child]
214+
parent[child] = -1
215+
branch_length[child] = 0.0
215216

216-
u = edge_parent[h]
217+
u = child_parent
218+
parent_u = parent[u]
217219
while u != -1:
218220
running_sum -= branch_length[u] * summary[u]
219-
state[u] -= state[edge_child[h]]
221+
state[u] -= state[child]
220222
summary[u] = state[u] * (n - state[u]) * two_over_denom
221223
running_sum += branch_length[u] * summary[u]
222-
u = parent[u]
224+
u = parent_u
225+
if u != -1:
226+
parent_u = parent[u]
223227

224228
# Process the incoming edges
225229
for j in range(tree_index.in_range.start, tree_index.in_range.stop):
226230
h = tree_index.in_range.order[j]
227-
u = edge_child[h]
228-
v = edge_parent[h]
231+
child = edge_child[h]
232+
child_parent = edge_parent[h]
229233

230-
parent[u] = v
231-
branch_length[u] = node_times[v] - node_times[u]
232-
running_sum += branch_length[u] * summary[u]
234+
parent[child] = child_parent
235+
branch_length[child] = node_times[child_parent] - node_times[child]
236+
running_sum += branch_length[child] * summary[child]
233237

234-
u = v
238+
u = child_parent
239+
parent_u = parent[u]
235240
while u != -1:
236241
running_sum -= branch_length[u] * summary[u]
237-
state[u] += state[edge_child[h]]
242+
state[u] += state[child]
238243
summary[u] = state[u] * (n - state[u]) * two_over_denom
239244
running_sum += branch_length[u] * summary[u]
240-
u = parent[u]
245+
u = parent_u
246+
if u != -1:
247+
parent_u = parent[u]
241248

242249
result += running_sum * (tree_index.interval[1] - tree_index.interval[0])
243250

python/tskit/jit/numba.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,10 @@
1717

1818
FORWARD = 1 #: Direction constant for forward tree traversal
1919
REVERSE = -1 #: Direction constant for reverse tree traversal
20+
21+
# Retrieve these here to avoid lookups in tight loops
2022
NODE_IS_SAMPLE = tskit.NODE_IS_SAMPLE
23+
NULL = tskit.NULL
2124

2225
edge_range_spec = [
2326
("start", numba.int32),
@@ -98,7 +101,7 @@ class NumbaTreeIndex:
98101
def __init__(self, ts):
99102
self.ts = ts
100103
self.index = -1
101-
self.direction = tskit.NULL
104+
self.direction = NULL
102105
self.interval = (0, 0)
103106
self.in_range = NumbaEdgeRange(0, 0, np.zeros(0, dtype=np.int32))
104107
self.out_range = NumbaEdgeRange(0, 0, np.zeros(0, dtype=np.int32))

0 commit comments

Comments
 (0)