@@ -183,37 +183,44 @@ JIT compilation - using simple loops and fixed-size arrays with minimal object a
183
183
# Process the outgoing edges
184
184
for j in range(tree_index.out_range.start, tree_index.out_range.stop):
185
185
h = tree_index.out_range.order[j]
186
- u = edge_child[h]
186
+ child = edge_child[h]
187
+ child_parent = edge_parent[h]
187
188
188
- running_sum -= branch_length[u ] * summary[u ]
189
- parent[u ] = -1
190
- branch_length[u ] = 0.0
189
+ running_sum -= branch_length[child ] * summary[child ]
190
+ parent[child ] = -1
191
+ branch_length[child ] = 0.0
191
192
192
- u = edge_parent[h]
193
+ u = child_parent
194
+ parent_u = parent[u]
193
195
while u != -1:
194
196
running_sum -= branch_length[u] * summary[u]
195
- state[u] -= state[edge_child[h] ]
197
+ state[u] -= state[child ]
196
198
summary[u] = state[u] * (n - state[u]) * two_over_denom
197
199
running_sum += branch_length[u] * summary[u]
198
- u = parent[u]
200
+ u = parent_u
201
+ if u != -1:
202
+ parent_u = parent[u]
199
203
200
204
# Process the incoming edges
201
205
for j in range(tree_index.in_range.start, tree_index.in_range.stop):
202
206
h = tree_index.in_range.order[j]
203
- u = edge_child[h]
204
- v = edge_parent[h]
207
+ child = edge_child[h]
208
+ child_parent = edge_parent[h]
205
209
206
- parent[u ] = v
207
- branch_length[u ] = node_times[v ] - node_times[u ]
208
- running_sum += branch_length[u ] * summary[u ]
210
+ parent[child ] = child_parent
211
+ branch_length[child ] = node_times[child_parent ] - node_times[child ]
212
+ running_sum += branch_length[child ] * summary[child ]
209
213
210
- u = v
214
+ u = child_parent
215
+ parent_u = parent[u]
211
216
while u != -1:
212
217
running_sum -= branch_length[u] * summary[u]
213
- state[u] += state[edge_child[h] ]
218
+ state[u] += state[child ]
214
219
summary[u] = state[u] * (n - state[u]) * two_over_denom
215
220
running_sum += branch_length[u] * summary[u]
216
- u = parent[u]
221
+ u = parent_u
222
+ if u != -1:
223
+ parent_u = parent[u]
217
224
218
225
result += running_sum * (
219
226
tree_index.interval[1] - tree_index.interval[0]
0 commit comments