Skip to content

Commit 7737a55

Browse files
authored
Merge pull request #218 from hyanwong/dont-pass-grid
No need to pass timepoints to posterior_mean_var
2 parents 36ad10c + 60cae1c commit 7737a55

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

tests/test_functions.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,19 +1480,22 @@ class TestPosteriorMeanVar:
14801480

14811481
def test_posterior_mean_var(self):
14821482
ts = utility_functions.single_tree_ts_n2()
1483-
grid = np.array([0, 1.2, 2])
14841483
for distr in ("gamma", "lognorm"):
14851484
posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, distr)
1486-
ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, grid, posterior)
1485+
ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, posterior)
14871486
assert np.array_equal(
1488-
mn_post, [0, 0, np.sum(grid * posterior[2]) / np.sum(posterior[2])]
1487+
mn_post,
1488+
[
1489+
0,
1490+
0,
1491+
np.sum(posterior.timepoints * posterior[2]) / np.sum(posterior[2]),
1492+
],
14891493
)
14901494

14911495
def test_node_metadata_single_tree_n2(self):
14921496
ts = utility_functions.single_tree_ts_n2()
1493-
grid = np.array([0, 1.2, 2])
14941497
posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, "lognorm")
1495-
ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, grid, posterior)
1498+
ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, posterior)
14961499
assert json.loads(ts_node_metadata.node(2).metadata)["mn"] == mn_post[2]
14971500
assert json.loads(ts_node_metadata.node(2).metadata)["vr"] == vr_post[2]
14981501

tsdate/core.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ def outside_maximization(self, *, eps, progress=None):
854854
return self.lik.timepoints[np.array(maximized_node_times).astype("int")]
855855

856856

857-
def posterior_mean_var(ts, timepoints, posterior, *, fixed_node_set=None):
857+
def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
858858
"""
859859
Mean and variance of node age in unscaled time. Fixed nodes will be given a mean
860860
of their exact time in the tree sequence, and zero variance (as long as they are
@@ -876,11 +876,10 @@ def posterior_mean_var(ts, timepoints, posterior, *, fixed_node_set=None):
876876
metadata_array = tskit.unpack_bytes(
877877
ts.tables.nodes.metadata, ts.tables.nodes.metadata_offset
878878
)
879-
timepoints = timepoints
880879
for row, node_id in zip(posterior.grid_data, posterior.nonfixed_nodes):
881-
mn_post[node_id] = np.sum(row * timepoints) / np.sum(row)
880+
mn_post[node_id] = np.sum(row * posterior.timepoints) / np.sum(row)
882881
vr_post[node_id] = np.sum(
883-
((mn_post[node_id] - (timepoints)) ** 2) * (row / np.sum(row))
882+
((mn_post[node_id] - (posterior.timepoints)) ** 2) * (row / np.sum(row))
884883
)
885884
metadata_array[node_id] = json.dumps(
886885
{"mn": mn_post[node_id], "vr": vr_post[node_id]}
@@ -1131,7 +1130,7 @@ def get_dates(
11311130
normalize=outside_normalize, ignore_oldest_root=ignore_oldest_root
11321131
)
11331132
tree_sequence, mn_post, _ = posterior_mean_var(
1134-
tree_sequence, priors.timepoints, posterior, fixed_node_set=fixed_nodes
1133+
tree_sequence, posterior, fixed_node_set=fixed_nodes
11351134
)
11361135
elif method == "maximization":
11371136
if mutation_rate is not None:

0 commit comments

Comments
 (0)