Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 163 additions & 34 deletions python/tests/test_tree_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,67 +145,95 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True):


def naive_branch_general_stat(
ts, w, f, windows=None, polarised=False, span_normalise=True
ts, w, f, windows=None, time_windows=None, polarised=False, span_normalise=True
):
# NOTE: does not behave correctly for unpolarised stats
# with non-ancestral material.
if windows is None:
windows = [0.0, ts.sequence_length]
drop_time_windows = time_windows is None
if time_windows is None:
time_windows = [0.0, np.inf]
else:
if time_windows[0] != 0:
time_windows = [0] + time_windows
n, k = w.shape
tw = len(time_windows) - 1
# hack to determine m
m = len(f(w[0]))
total = np.sum(w, axis=0)

sigma = np.zeros((ts.num_trees, m))
for tree in ts.trees():
x = np.zeros((ts.num_nodes, k))
x[ts.samples()] = w
for u in tree.nodes(order="postorder"):
for v in tree.children(u):
x[u] += x[v]
if polarised:
s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes())
sigma = np.zeros((ts.num_trees, tw, m))
for j, upper_time in enumerate(time_windows[1:]):
if np.isfinite(upper_time):
decap_ts = ts.decapitate(upper_time)
else:
s = sum(
tree.branch_length(u) * (f(x[u]) + f(total - x[u]))
for u in tree.nodes()
)
sigma[tree.index] = s * tree.span
decap_ts = ts
assert np.all(list(ts.samples()) == list(decap_ts.samples()))
for tree in decap_ts.trees():
x = np.zeros((decap_ts.num_nodes, k))
x[decap_ts.samples()] = w
for u in tree.nodes(order="postorder"):
for v in tree.children(u):
x[u] += x[v]
if polarised:
s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes())
else:
s = sum(
tree.branch_length(u) * (f(x[u]) + f(total - x[u]))
for u in tree.nodes()
)
sigma[tree.index, j, :] = s * tree.span
for j in range(1, tw):
sigma[:, j, :] = sigma[:, j, :] - sigma[:, j - 1, :]
if isinstance(windows, str) and windows == "trees":
# need to average across the windows
if span_normalise:
for j, tree in enumerate(ts.trees()):
sigma[j] /= tree.span
return sigma
out = sigma
else:
return windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
out = windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
if drop_time_windows:
assert out.ndim == 3
out = out[:, 0]
return out


def branch_general_stat(
ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True
ts,
sample_weights,
summary_func,
windows=None,
time_windows=None,
polarised=False,
span_normalise=True,
):
"""
Efficient implementation of the algorithm used as the basis for the
underlying C version.
"""
n, state_dim = sample_weights.shape
windows = ts.parse_windows(windows)
drop_time_windows = time_windows is None
time_windows = ts.parse_time_windows(time_windows)
num_windows = windows.shape[0] - 1
num_time_windows = time_windows.shape[0] - 1

# Determine result_dim
result_dim = len(summary_func(sample_weights[0]))
result = np.zeros((num_windows, result_dim))
result = np.zeros((num_windows, num_time_windows, result_dim))
state = np.zeros((ts.num_nodes, state_dim))
state[ts.samples()] = sample_weights
total_weight = np.sum(sample_weights, axis=0)

time = ts.tables.nodes.time
parent = np.zeros(ts.num_nodes, dtype=np.int32) - 1
branch_length = np.zeros(ts.num_nodes)
branch_length = np.zeros((num_time_windows, ts.num_nodes))
# The value of summary_func(u) for every node.
summary = np.zeros((ts.num_nodes, result_dim))
# The result for the current tree *not* weighted by span.
running_sum = np.zeros(result_dim)
running_sum = np.zeros((num_time_windows, result_dim))

def polarised_summary(u):
s = summary_func(state[u])
Expand All @@ -217,31 +245,48 @@ def polarised_summary(u):
summary[u] = polarised_summary(u)

window_index = 0

def update_sum(u, sign):
time_window_index = 0
if parent[u] != -1:
while (
time_window_index < num_time_windows
and time_windows[time_window_index] < time[parent[u]]
):
running_sum[time_window_index] += sign * (
branch_length[time_window_index, u] * summary[u]
)
time_window_index += 1

for (t_left, t_right), edges_out, edges_in in ts.edge_diffs():
for edge in edges_out:
u = edge.child
running_sum -= branch_length[u] * summary[u]
update_sum(u, sign=-1)
u = edge.parent
while u != -1:
running_sum -= branch_length[u] * summary[u]
update_sum(u, sign=-1)
state[u] -= state[edge.child]
summary[u] = polarised_summary(u)
running_sum += branch_length[u] * summary[u]
update_sum(u, sign=+1)
u = parent[u]
parent[edge.child] = -1
branch_length[edge.child] = 0
for tw in range(num_time_windows):
branch_length[tw, edge.child] = 0

for edge in edges_in:
parent[edge.child] = edge.parent
branch_length[edge.child] = time[edge.parent] - time[edge.child]
for tw in range(num_time_windows):
branch_length[tw, edge.child] = min(
time[edge.parent], time_windows[tw + 1]
) - max(time[edge.child], time_windows[tw])
u = edge.child
running_sum += branch_length[u] * summary[u]
update_sum(u, sign=+1)
u = edge.parent
while u != -1:
running_sum -= branch_length[u] * summary[u]
update_sum(u, sign=-1)
state[u] += state[edge.child]
summary[u] = polarised_summary(u)
running_sum += branch_length[u] * summary[u]
update_sum(u, sign=+1)
u = parent[u]

# Update the windows
Expand All @@ -253,7 +298,12 @@ def polarised_summary(u):
right = min(t_right, w_right)
span = right - left
assert span > 0
result[window_index] += running_sum * span
time_window_index = 0
while time_window_index < num_time_windows:
result[window_index, time_window_index] += (
running_sum[time_window_index] * span
)
time_window_index += 1
if w_right <= t_right:
window_index += 1
else:
Expand All @@ -263,6 +313,9 @@ def polarised_summary(u):

# print("window_index:", window_index, windows.shape)
assert window_index == windows.shape[0] - 1
if drop_time_windows:
assert result.ndim == 3
result = result[:, 0]
if span_normalise:
for j in range(num_windows):
result[j] /= windows[j + 1] - windows[j]
Expand Down Expand Up @@ -322,13 +375,20 @@ def naive_site_general_stat(


def site_general_stat(
ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True
ts,
sample_weights,
summary_func,
windows=None,
time_windows=None,
polarised=False,
span_normalise=True,
):
"""
Problem: 'sites' is different that the other windowing options
because if we output by site we don't want to normalize by length of the window.
Solution: we pass an argument "normalize", to the windowing function.
"""
assert time_windows is None
windows = ts.parse_windows(windows)
num_windows = windows.shape[0] - 1
n, state_dim = sample_weights.shape
Expand Down Expand Up @@ -425,12 +485,19 @@ def naive_node_general_stat(


def node_general_stat(
ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True
ts,
sample_weights,
summary_func,
windows=None,
time_windows=None,
polarised=False,
span_normalise=True,
):
"""
Efficient implementation of the algorithm used as the basis for the
underlying C version.
"""
assert time_windows is None
n, state_dim = sample_weights.shape
windows = ts.parse_windows(windows)
num_windows = windows.shape[0] - 1
Expand Down Expand Up @@ -500,6 +567,7 @@ def general_stat(
sample_weights,
summary_func,
windows=None,
time_windows=None,
polarised=False,
mode="site",
span_normalise=True,
Expand All @@ -518,6 +586,7 @@ def general_stat(
sample_weights,
summary_func,
windows=windows,
time_windows=time_windows,
polarised=polarised,
span_normalise=span_normalise,
)
Expand Down Expand Up @@ -3534,7 +3603,9 @@ class TestSitef3(Testf3, MutatedTopologyExamplesMixin):
############################################


def branch_f4(ts, sample_sets, indexes, windows=None, span_normalise=True):
def branch_f4(
ts, sample_sets, indexes, windows=None, time_windows=None, span_normalise=True
):
windows = ts.parse_windows(windows)
out = np.zeros((len(windows) - 1, len(indexes)))
for j in range(len(windows) - 1):
Expand Down Expand Up @@ -3674,7 +3745,15 @@ def node_f4(ts, sample_sets, indexes, windows=None, span_normalise=True):
return out


def f4(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise=True):
def f4(
ts,
sample_sets,
indexes=None,
windows=None,
time_windows=None,
mode="site",
span_normalise=True,
):
"""
Patterson's f4 statistic definitions.
"""
Expand Down Expand Up @@ -6994,3 +7073,53 @@ def f_too_long(_):
output_dim=1,
strict=False,
)


class TestTimeWindows:

def test_general_stat(self, four_taxa_test_case):
# 1.00┊ 7 ┊ ┊ ┊
# ┊ ┏━┻━┓ ┊ ┊ ┊
# 0.70┊ ┃ ┃ ┊ ┊ 6 ┊
# ┊ ┃ ┃ ┊ ┊ ┏━┻━┓ ┊
# 0.50┊ ┃ 5 ┊ 5 ┊ ┃ 5 ┊
# ┊ ┃ ┏┻━┓ ┊ ┏━┻━┓ ┊ ┃ ┏┻━┓ ┊
# 0.40┊ ┃ 8 ┃ ┊ 4 8 ┊ ┃ 8 ┃ ┊
# ┊ ┃ ┏┻┓ ┃ ┊ ┏┻┓ ┏┻┓ ┊ ┃ ┏┻┓ ┃ ┊
# 0.00┊ 0 1 3 2 ┊ 0 2 1 3 ┊ 0 1 3 2 ┊
# 0.00 0.20 0.80 2.50
ts = four_taxa_test_case
true_x = np.array(
[
[
[
0.2 * (1 + 0.5 + 0.4)
+ (0.8 - 0.2) * (1 + 0.8)
+ (2.5 - 0.8) * (1.0 + 0.5 + 0.4)
],
[0.2 * 1.0 + 0 + (2.5 - 0.8) * 0.4],
]
]
)

n = ts.num_samples

def f(x):
return (x > 0) * (1 - x / n)

W = np.ones((ts.num_samples, 1))
x = naive_branch_general_stat(
ts, W, f, time_windows=[0, 0.5, 2.0], span_normalise=False
)
np.testing.assert_allclose(x, true_x)

x0 = branch_general_stat(ts, W, f, time_windows=None, span_normalise=False)
x1 = naive_branch_general_stat(
ts, W, f, time_windows=None, span_normalise=False
)
np.testing.assert_allclose(x0, x1)
x_tw = branch_general_stat(
ts, W, f, time_windows=[0, 0.5, 2.0], span_normalise=False
)

np.testing.assert_allclose(x, x_tw)
Loading