diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index d7b2d280a0..099faf17ff 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -145,44 +145,69 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True): def naive_branch_general_stat( - ts, w, f, windows=None, polarised=False, span_normalise=True + ts, w, f, windows=None, time_windows=None, polarised=False, span_normalise=True ): # NOTE: does not behave correctly for unpolarised stats # with non-ancestral material. if windows is None: windows = [0.0, ts.sequence_length] + drop_time_windows = time_windows is None + if time_windows is None: + time_windows = [0.0, np.inf] + else: + if time_windows[0] != 0: + time_windows = [0] + time_windows n, k = w.shape + tw = len(time_windows) - 1 # hack to determine m m = len(f(w[0])) total = np.sum(w, axis=0) - sigma = np.zeros((ts.num_trees, m)) - for tree in ts.trees(): - x = np.zeros((ts.num_nodes, k)) - x[ts.samples()] = w - for u in tree.nodes(order="postorder"): - for v in tree.children(u): - x[u] += x[v] - if polarised: - s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes()) + sigma = np.zeros((ts.num_trees, tw, m)) + for j, upper_time in enumerate(time_windows[1:]): + if np.isfinite(upper_time): + decap_ts = ts.decapitate(upper_time) else: - s = sum( - tree.branch_length(u) * (f(x[u]) + f(total - x[u])) - for u in tree.nodes() - ) - sigma[tree.index] = s * tree.span + decap_ts = ts + assert np.all(list(ts.samples()) == list(decap_ts.samples())) + for tree in decap_ts.trees(): + x = np.zeros((decap_ts.num_nodes, k)) + x[decap_ts.samples()] = w + for u in tree.nodes(order="postorder"): + for v in tree.children(u): + x[u] += x[v] + if polarised: + s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes()) + else: + s = sum( + tree.branch_length(u) * (f(x[u]) + f(total - x[u])) + for u in tree.nodes() + ) + sigma[tree.index, j, :] = s * tree.span + for j in range(1, tw): + sigma[:, j, :] = sigma[:, j, :] - sigma[:, j - 1, :] if isinstance(windows, str) and windows == "trees": # need to average across the windows if span_normalise: for j, tree in enumerate(ts.trees()): sigma[j] /= tree.span - return sigma + out = sigma else: - return windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise) + out = windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise) + if drop_time_windows: + assert out.ndim == 3 + out = out[:, 0] + return out def branch_general_stat( - ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True + ts, + sample_weights, + summary_func, + windows=None, + time_windows=None, + polarised=False, + span_normalise=True, ): """ Efficient implementation of the algorithm used as the basis for the @@ -190,22 +215,25 @@ def branch_general_stat( """ n, state_dim = sample_weights.shape windows = ts.parse_windows(windows) + drop_time_windows = time_windows is None + time_windows = ts.parse_time_windows(time_windows) num_windows = windows.shape[0] - 1 + num_time_windows = time_windows.shape[0] - 1 # Determine result_dim result_dim = len(summary_func(sample_weights[0])) - result = np.zeros((num_windows, result_dim)) + result = np.zeros((num_windows, num_time_windows, result_dim)) state = np.zeros((ts.num_nodes, state_dim)) state[ts.samples()] = sample_weights total_weight = np.sum(sample_weights, axis=0) time = ts.tables.nodes.time parent = np.zeros(ts.num_nodes, dtype=np.int32) - 1 - branch_length = np.zeros(ts.num_nodes) + branch_length = np.zeros((num_time_windows, ts.num_nodes)) # The value of summary_func(u) for every node. summary = np.zeros((ts.num_nodes, result_dim)) # The result for the current tree *not* weighted by span. - running_sum = np.zeros(result_dim) + running_sum = np.zeros((num_time_windows, result_dim)) def polarised_summary(u): s = summary_func(state[u]) @@ -217,31 +245,48 @@ def polarised_summary(u): summary[u] = polarised_summary(u) window_index = 0 + + def update_sum(u, sign): + time_window_index = 0 + if parent[u] != -1: + while ( + time_window_index < num_time_windows + and time_windows[time_window_index] < time[parent[u]] + ): + running_sum[time_window_index] += sign * ( + branch_length[time_window_index, u] * summary[u] + ) + time_window_index += 1 + for (t_left, t_right), edges_out, edges_in in ts.edge_diffs(): for edge in edges_out: u = edge.child - running_sum -= branch_length[u] * summary[u] + update_sum(u, sign=-1) u = edge.parent while u != -1: - running_sum -= branch_length[u] * summary[u] + update_sum(u, sign=-1) state[u] -= state[edge.child] summary[u] = polarised_summary(u) - running_sum += branch_length[u] * summary[u] + update_sum(u, sign=+1) u = parent[u] parent[edge.child] = -1 - branch_length[edge.child] = 0 + for tw in range(num_time_windows): + branch_length[tw, edge.child] = 0 for edge in edges_in: parent[edge.child] = edge.parent - branch_length[edge.child] = time[edge.parent] - time[edge.child] + for tw in range(num_time_windows): + branch_length[tw, edge.child] = min( + time[edge.parent], time_windows[tw + 1] + ) - max(time[edge.child], time_windows[tw]) u = edge.child - running_sum += branch_length[u] * summary[u] + update_sum(u, sign=+1) u = edge.parent while u != -1: - running_sum -= branch_length[u] * summary[u] + update_sum(u, sign=-1) state[u] += state[edge.child] summary[u] = polarised_summary(u) - running_sum += branch_length[u] * summary[u] + update_sum(u, sign=+1) u = parent[u] # Update the windows @@ -253,7 +298,12 @@ def polarised_summary(u): right = min(t_right, w_right) span = right - left assert span > 0 - result[window_index] += running_sum * span + time_window_index = 0 + while time_window_index < num_time_windows: + result[window_index, time_window_index] += ( + running_sum[time_window_index] * span + ) + time_window_index += 1 if w_right <= t_right: window_index += 1 else: @@ -263,6 +313,9 @@ def polarised_summary(u): # print("window_index:", window_index, windows.shape) assert window_index == windows.shape[0] - 1 + if drop_time_windows: + assert result.ndim == 3 + result = result[:, 0] if span_normalise: for j in range(num_windows): result[j] /= windows[j + 1] - windows[j] @@ -322,13 +375,20 @@ def naive_site_general_stat( def site_general_stat( - ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True + ts, + sample_weights, + summary_func, + windows=None, + time_windows=None, + polarised=False, + span_normalise=True, ): """ Problem: 'sites' is different that the other windowing options because if we output by site we don't want to normalize by length of the window. Solution: we pass an argument "normalize", to the windowing function. """ + assert time_windows is None windows = ts.parse_windows(windows) num_windows = windows.shape[0] - 1 n, state_dim = sample_weights.shape @@ -425,12 +485,19 @@ def naive_node_general_stat( def node_general_stat( - ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True + ts, + sample_weights, + summary_func, + windows=None, + time_windows=None, + polarised=False, + span_normalise=True, ): """ Efficient implementation of the algorithm used as the basis for the underlying C version. """ + assert time_windows is None n, state_dim = sample_weights.shape windows = ts.parse_windows(windows) num_windows = windows.shape[0] - 1 @@ -500,6 +567,7 @@ def general_stat( sample_weights, summary_func, windows=None, + time_windows=None, polarised=False, mode="site", span_normalise=True, @@ -518,6 +586,7 @@ def general_stat( sample_weights, summary_func, windows=windows, + time_windows=time_windows, polarised=polarised, span_normalise=span_normalise, ) @@ -3534,7 +3603,9 @@ class TestSitef3(Testf3, MutatedTopologyExamplesMixin): ############################################ -def branch_f4(ts, sample_sets, indexes, windows=None, span_normalise=True): +def branch_f4( + ts, sample_sets, indexes, windows=None, time_windows=None, span_normalise=True +): windows = ts.parse_windows(windows) out = np.zeros((len(windows) - 1, len(indexes))) for j in range(len(windows) - 1): @@ -3674,7 +3745,15 @@ def node_f4(ts, sample_sets, indexes, windows=None, span_normalise=True): return out -def f4(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise=True): +def f4( + ts, + sample_sets, + indexes=None, + windows=None, + time_windows=None, + mode="site", + span_normalise=True, +): """ Patterson's f4 statistic definitions. """ @@ -6994,3 +7073,53 @@ def f_too_long(_): output_dim=1, strict=False, ) + + +class TestTimeWindows: + + def test_general_stat(self, four_taxa_test_case): + # 1.00┊ 7 ┊ ┊ ┊ + # ┊ ┏━┻━┓ ┊ ┊ ┊ + # 0.70┊ ┃ ┃ ┊ ┊ 6 ┊ + # ┊ ┃ ┃ ┊ ┊ ┏━┻━┓ ┊ + # 0.50┊ ┃ 5 ┊ 5 ┊ ┃ 5 ┊ + # ┊ ┃ ┏┻━┓ ┊ ┏━┻━┓ ┊ ┃ ┏┻━┓ ┊ + # 0.40┊ ┃ 8 ┃ ┊ 4 8 ┊ ┃ 8 ┃ ┊ + # ┊ ┃ ┏┻┓ ┃ ┊ ┏┻┓ ┏┻┓ ┊ ┃ ┏┻┓ ┃ ┊ + # 0.00┊ 0 1 3 2 ┊ 0 2 1 3 ┊ 0 1 3 2 ┊ + # 0.00 0.20 0.80 2.50 + ts = four_taxa_test_case + true_x = np.array( + [ + [ + [ + 0.2 * (1 + 0.5 + 0.4) + + (0.8 - 0.2) * (1 + 0.8) + + (2.5 - 0.8) * (1.0 + 0.5 + 0.4) + ], + [0.2 * 1.0 + 0 + (2.5 - 0.8) * 0.4], + ] + ] + ) + + n = ts.num_samples + + def f(x): + return (x > 0) * (1 - x / n) + + W = np.ones((ts.num_samples, 1)) + x = naive_branch_general_stat( + ts, W, f, time_windows=[0, 0.5, 2.0], span_normalise=False + ) + np.testing.assert_allclose(x, true_x) + + x0 = branch_general_stat(ts, W, f, time_windows=None, span_normalise=False) + x1 = naive_branch_general_stat( + ts, W, f, time_windows=None, span_normalise=False + ) + np.testing.assert_allclose(x0, x1) + x_tw = branch_general_stat( + ts, W, f, time_windows=[0, 0.5, 2.0], span_normalise=False + ) + + np.testing.assert_allclose(x, x_tw)