@@ -145,67 +145,95 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True):
145
145
146
146
147
147
def naive_branch_general_stat(
    ts, w, f, windows=None, time_windows=None, polarised=False, span_normalise=True
):
    """
    Naive reference implementation of the branch-mode general statistic,
    optionally split into time windows.

    For every time-window upper bound the tree sequence is decapitated at
    that time, the statistic is summed over all remaining branches of each
    tree, and the cumulative values are then differenced so that each time
    window only holds its own contribution.

    :param ts: The tree sequence to compute the statistic on.
    :param w: Array of sample weights with shape (n, k), one row per sample.
    :param f: Summary function mapping a length-k weight vector to a
        length-m result.
    :param windows: Genomic window breakpoints, the string "trees", or None
        for a single window covering the whole sequence.
    :param time_windows: Increasing sequence of time breakpoints, or None
        for a single window [0, inf). If the first breakpoint is not 0, a
        leading 0 is prepended.
    :param polarised: If False, ``f(total - x)`` is added to ``f(x)`` for
        every branch.
    :param span_normalise: Whether to divide by the span of each window.
    :return: The statistic per window and time window. When ``time_windows``
        is None the time-window axis is dropped from the result.
    """
    # NOTE: does not behave correctly for unpolarised stats
    # with non-ancestral material.
    if windows is None:
        windows = [0.0, ts.sequence_length]
    drop_time_windows = time_windows is None
    if drop_time_windows:
        time_windows = [0.0, np.inf]
    else:
        # Work on a plain list so that prepending 0 really prepends: with a
        # numpy array ``[0] + time_windows`` would be an element-wise add,
        # and with a tuple it would raise TypeError.
        time_windows = list(time_windows)
        if time_windows[0] != 0:
            time_windows = [0] + time_windows
    n, k = w.shape
    tw = len(time_windows) - 1
    # hack to determine m
    m = len(f(w[0]))
    total = np.sum(w, axis=0)

    # sigma[tree, j] first holds the statistic for all branches younger than
    # time_windows[j + 1], i.e. cumulative over time windows.
    sigma = np.zeros((ts.num_trees, tw, m))
    for j, upper_time in enumerate(time_windows[1:]):
        if np.isfinite(upper_time):
            decap_ts = ts.decapitate(upper_time)
        else:
            decap_ts = ts
        assert np.all(list(ts.samples()) == list(decap_ts.samples()))
        # NOTE(review): indexing sigma by tree.index below assumes the
        # decapitated tree sequence has the same trees as ts - confirm.
        for tree in decap_ts.trees():
            x = np.zeros((decap_ts.num_nodes, k))
            x[decap_ts.samples()] = w
            # Accumulate the weights of all samples below each node.
            for u in tree.nodes(order="postorder"):
                for v in tree.children(u):
                    x[u] += x[v]
            if polarised:
                s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes())
            else:
                s = sum(
                    tree.branch_length(u) * (f(x[u]) + f(total - x[u]))
                    for u in tree.nodes()
                )
            sigma[tree.index, j, :] = s * tree.span
    # Convert cumulative values into per-time-window values. np.diff works
    # on the untouched cumulative array; an ascending in-place loop would
    # subtract already-differenced entries for the third window onwards.
    sigma[:, 1:, :] = np.diff(sigma, axis=1)
    if isinstance(windows, str) and windows == "trees":
        # need to average across the windows
        if span_normalise:
            for j, tree in enumerate(ts.trees()):
                sigma[j] /= tree.span
        out = sigma
    else:
        out = windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
    if drop_time_windows:
        assert out.ndim == 3
        out = out[:, 0]
    return out
182
201
183
202
184
203
def branch_general_stat (
185
- ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
204
+ ts ,
205
+ sample_weights ,
206
+ summary_func ,
207
+ windows = None ,
208
+ time_windows = None ,
209
+ polarised = False ,
210
+ span_normalise = True ,
186
211
):
187
212
"""
188
213
Efficient implementation of the algorithm used as the basis for the
189
214
underlying C version.
190
215
"""
191
216
n , state_dim = sample_weights .shape
192
217
windows = ts .parse_windows (windows )
218
+ drop_time_windows = time_windows is None
219
+ time_windows = ts .parse_time_windows (time_windows )
193
220
num_windows = windows .shape [0 ] - 1
221
+ num_time_windows = time_windows .shape [0 ] - 1
194
222
195
223
# Determine result_dim
196
224
result_dim = len (summary_func (sample_weights [0 ]))
197
- result = np .zeros ((num_windows , result_dim ))
225
+ result = np .zeros ((num_windows , num_time_windows , result_dim ))
198
226
state = np .zeros ((ts .num_nodes , state_dim ))
199
227
state [ts .samples ()] = sample_weights
200
228
total_weight = np .sum (sample_weights , axis = 0 )
201
229
202
230
time = ts .tables .nodes .time
203
231
parent = np .zeros (ts .num_nodes , dtype = np .int32 ) - 1
204
- branch_length = np .zeros (ts .num_nodes )
232
+ branch_length = np .zeros (( num_time_windows , ts .num_nodes ) )
205
233
# The value of summary_func(u) for every node.
206
234
summary = np .zeros ((ts .num_nodes , result_dim ))
207
235
# The result for the current tree *not* weighted by span.
208
- running_sum = np .zeros (result_dim )
236
+ running_sum = np .zeros (( num_time_windows , result_dim ) )
209
237
210
238
def polarised_summary (u ):
211
239
s = summary_func (state [u ])
@@ -217,31 +245,48 @@ def polarised_summary(u):
217
245
summary [u ] = polarised_summary (u )
218
246
219
247
window_index = 0
248
+
249
+ def update_sum (u , sign ):
250
+ time_window_index = 0
251
+ if parent [u ] != - 1 :
252
+ while (
253
+ time_window_index < num_time_windows
254
+ and time_windows [time_window_index ] < time [parent [u ]]
255
+ ):
256
+ running_sum [time_window_index ] += sign * (
257
+ branch_length [time_window_index , u ] * summary [u ]
258
+ )
259
+ time_window_index += 1
260
+
220
261
for (t_left , t_right ), edges_out , edges_in in ts .edge_diffs ():
221
262
for edge in edges_out :
222
263
u = edge .child
223
- running_sum -= branch_length [ u ] * summary [ u ]
264
+ update_sum ( u , sign = - 1 )
224
265
u = edge .parent
225
266
while u != - 1 :
226
- running_sum -= branch_length [ u ] * summary [ u ]
267
+ update_sum ( u , sign = - 1 )
227
268
state [u ] -= state [edge .child ]
228
269
summary [u ] = polarised_summary (u )
229
- running_sum += branch_length [ u ] * summary [ u ]
270
+ update_sum ( u , sign = + 1 )
230
271
u = parent [u ]
231
272
parent [edge .child ] = - 1
232
- branch_length [edge .child ] = 0
273
+ for tw in range (num_time_windows ):
274
+ branch_length [tw , edge .child ] = 0
233
275
234
276
for edge in edges_in :
235
277
parent [edge .child ] = edge .parent
236
- branch_length [edge .child ] = time [edge .parent ] - time [edge .child ]
278
+ for tw in range (num_time_windows ):
279
+ branch_length [tw , edge .child ] = min (
280
+ time [edge .parent ], time_windows [tw + 1 ]
281
+ ) - max (time [edge .child ], time_windows [tw ])
237
282
u = edge .child
238
- running_sum += branch_length [ u ] * summary [ u ]
283
+ update_sum ( u , sign = + 1 )
239
284
u = edge .parent
240
285
while u != - 1 :
241
- running_sum -= branch_length [ u ] * summary [ u ]
286
+ update_sum ( u , sign = - 1 )
242
287
state [u ] += state [edge .child ]
243
288
summary [u ] = polarised_summary (u )
244
- running_sum += branch_length [ u ] * summary [ u ]
289
+ update_sum ( u , sign = + 1 )
245
290
u = parent [u ]
246
291
247
292
# Update the windows
@@ -253,7 +298,12 @@ def polarised_summary(u):
253
298
right = min (t_right , w_right )
254
299
span = right - left
255
300
assert span > 0
256
- result [window_index ] += running_sum * span
301
+ time_window_index = 0
302
+ while time_window_index < num_time_windows :
303
+ result [window_index , time_window_index ] += (
304
+ running_sum [time_window_index ] * span
305
+ )
306
+ time_window_index += 1
257
307
if w_right <= t_right :
258
308
window_index += 1
259
309
else :
@@ -263,6 +313,9 @@ def polarised_summary(u):
263
313
264
314
# print("window_index:", window_index, windows.shape)
265
315
assert window_index == windows .shape [0 ] - 1
316
+ if drop_time_windows :
317
+ assert result .ndim == 3
318
+ result = result [:, 0 ]
266
319
if span_normalise :
267
320
for j in range (num_windows ):
268
321
result [j ] /= windows [j + 1 ] - windows [j ]
@@ -322,13 +375,20 @@ def naive_site_general_stat(
322
375
323
376
324
377
def site_general_stat (
325
- ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
378
+ ts ,
379
+ sample_weights ,
380
+ summary_func ,
381
+ windows = None ,
382
+ time_windows = None ,
383
+ polarised = False ,
384
+ span_normalise = True ,
326
385
):
327
386
"""
328
387
Problem: 'sites' is different from the other windowing options
329
388
because if we output by site we don't want to normalize by length of the window.
330
389
Solution: we pass an argument "normalize", to the windowing function.
331
390
"""
391
+ assert time_windows is None
332
392
windows = ts .parse_windows (windows )
333
393
num_windows = windows .shape [0 ] - 1
334
394
n , state_dim = sample_weights .shape
@@ -425,12 +485,19 @@ def naive_node_general_stat(
425
485
426
486
427
487
def node_general_stat (
428
- ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
488
+ ts ,
489
+ sample_weights ,
490
+ summary_func ,
491
+ windows = None ,
492
+ time_windows = None ,
493
+ polarised = False ,
494
+ span_normalise = True ,
429
495
):
430
496
"""
431
497
Efficient implementation of the algorithm used as the basis for the
432
498
underlying C version.
433
499
"""
500
+ assert time_windows is None
434
501
n , state_dim = sample_weights .shape
435
502
windows = ts .parse_windows (windows )
436
503
num_windows = windows .shape [0 ] - 1
@@ -500,6 +567,7 @@ def general_stat(
500
567
sample_weights ,
501
568
summary_func ,
502
569
windows = None ,
570
+ time_windows = None ,
503
571
polarised = False ,
504
572
mode = "site" ,
505
573
span_normalise = True ,
@@ -518,6 +586,7 @@ def general_stat(
518
586
sample_weights ,
519
587
summary_func ,
520
588
windows = windows ,
589
+ time_windows = time_windows ,
521
590
polarised = polarised ,
522
591
span_normalise = span_normalise ,
523
592
)
@@ -3534,7 +3603,9 @@ class TestSitef3(Testf3, MutatedTopologyExamplesMixin):
3534
3603
############################################
3535
3604
3536
3605
3537
- def branch_f4 (ts , sample_sets , indexes , windows = None , span_normalise = True ):
3606
+ def branch_f4 (
3607
+ ts , sample_sets , indexes , windows = None , time_windows = None , span_normalise = True
3608
+ ):
3538
3609
windows = ts .parse_windows (windows )
3539
3610
out = np .zeros ((len (windows ) - 1 , len (indexes )))
3540
3611
for j in range (len (windows ) - 1 ):
@@ -3674,7 +3745,15 @@ def node_f4(ts, sample_sets, indexes, windows=None, span_normalise=True):
3674
3745
return out
3675
3746
3676
3747
3677
- def f4 (ts , sample_sets , indexes = None , windows = None , mode = "site" , span_normalise = True ):
3748
+ def f4 (
3749
+ ts ,
3750
+ sample_sets ,
3751
+ indexes = None ,
3752
+ windows = None ,
3753
+ time_windows = None ,
3754
+ mode = "site" ,
3755
+ span_normalise = True ,
3756
+ ):
3678
3757
"""
3679
3758
Patterson's f4 statistic definitions.
3680
3759
"""
@@ -6994,3 +7073,53 @@ def f_too_long(_):
6994
7073
output_dim = 1 ,
6995
7074
strict = False ,
6996
7075
)
7076
+
7077
+
7078
class TestTimeWindows:
    """
    Checks that the branch general statistics can be split into time windows.
    """

    def test_general_stat(self, four_taxa_test_case):
        # Topology of the fixture (node times on the left, tree
        # breakpoints along the bottom):
        #
        # 1.00┊    7      ┊           ┊           ┊
        #     ┊  ┏━┻━┓    ┊           ┊           ┊
        # 0.70┊  ┃   ┃    ┊           ┊    6      ┊
        #     ┊  ┃   ┃    ┊           ┊  ┏━┻━┓    ┊
        # 0.50┊  ┃   5    ┊     5     ┊  ┃   5    ┊
        #     ┊  ┃  ┏┻━┓  ┊   ┏━┻━┓   ┊  ┃  ┏┻━┓  ┊
        # 0.40┊  ┃  8  ┃  ┊   4   8   ┊  ┃  8  ┃  ┊
        #     ┊  ┃ ┏┻┓ ┃  ┊  ┏┻┓ ┏┻┓  ┊  ┃ ┏┻┓ ┃  ┊
        # 0.00┊  0 1 3 2  ┊  0 2 1 3  ┊  0 1 3 2  ┊
        #    0.00        0.20        0.80        2.50
        ts = four_taxa_test_case
        time_breaks = (0, 0.5, 2.0)
        num_samples = ts.num_samples

        # Span-weighted branch length of each tree that falls below time
        # 0.5 (first time window) and above it (second time window).
        below_half = (
            0.2 * (1 + 0.5 + 0.4)
            + (0.8 - 0.2) * (1 + 0.8)
            + (2.5 - 0.8) * (1.0 + 0.5 + 0.4)
        )
        above_half = 0.2 * 1.0 + 0 + (2.5 - 0.8) * 0.4
        expected = np.array([[[below_half], [above_half]]])

        def summary(x):
            return (x > 0) * (1 - x / num_samples)

        weights = np.ones((ts.num_samples, 1))

        # The naive implementation must reproduce the hand-computed values.
        naive_tw = naive_branch_general_stat(
            ts,
            weights,
            summary,
            time_windows=list(time_breaks),
            span_normalise=False,
        )
        np.testing.assert_allclose(naive_tw, expected)

        # Without time windows the efficient and naive versions must agree.
        fast_plain = branch_general_stat(
            ts, weights, summary, time_windows=None, span_normalise=False
        )
        naive_plain = naive_branch_general_stat(
            ts, weights, summary, time_windows=None, span_normalise=False
        )
        np.testing.assert_allclose(fast_plain, naive_plain)

        # With time windows the efficient version must match the naive one.
        fast_tw = branch_general_stat(
            ts,
            weights,
            summary,
            time_windows=list(time_breaks),
            span_normalise=False,
        )
        np.testing.assert_allclose(naive_tw, fast_tw)
0 commit comments