Skip to content

Commit 85a51f3

Browse files
committed
Add Spearman's R to the accuracy tests
And use better r2 estimates
1 parent b0ea6a1 commit 85a51f3

File tree

1 file changed

+38
-29
lines changed

1 file changed

+38
-29
lines changed

tests/test_accuracy.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import msprime
3030
import numpy as np
3131
import pytest
32+
import scipy
3233
import tskit
3334

3435
import tsdate
@@ -46,7 +47,7 @@ def test_make_static_files(self, request):
4647
So that we are assured of using the same tree sequence, regardless of the
4748
version and random number generator used in msprime, we keep these
4849
as static files and only run this function when explicitly specified, e.g. via
49-
pytest test_accuracy.py::TestAccuracy::create_static_files
50+
pytest test_accuracy.py::TestAccuracy::test_make_static_files
5051
"""
5152
mu = 1e-6
5253
Ne = 1e4
@@ -74,14 +75,22 @@ def test_make_static_files(self, request):
7475
ts.dump(os.path.join(request.fspath.dirname, "data", f"{name}.trees"))
7576

7677
@pytest.mark.parametrize(
77-
"ts_name,min_r2_ts,min_r2_posterior",
78+
"ts_name,min_r2_ts,min_r2_unconstrained,min_spear_ts,min_spear_unconstrained",
7879
[
79-
("one_tree", 0.94776615238, 0.94776615238),
80-
("few_trees", 0.96605244, 0.96605244),
81-
("many_trees", 0.92646, 0.92646),
80+
("one_tree", 0.98601, 0.98601, 0.97719, 0.97719),
81+
("few_trees", 0.98220, 0.98220, 0.97744, 0.97744),
82+
("many_trees", 0.93449, 0.93449, 0.964547, 0.964547),
8283
],
8384
)
84-
def test_basic(self, ts_name, min_r2_ts, min_r2_posterior, request):
85+
def test_basic(
86+
self,
87+
ts_name,
88+
min_r2_ts,
89+
min_r2_unconstrained,
90+
min_spear_ts,
91+
min_spear_unconstrained,
92+
request,
93+
):
8594
ts = tskit.load(
8695
os.path.join(request.fspath.dirname, "data", ts_name + ".trees")
8796
)
@@ -97,29 +106,29 @@ def test_basic(self, ts_name, min_r2_ts, min_r2_posterior, request):
97106
dts, posteriors = tsdate.date(
98107
ts, Ne=Ne, mutation_rate=mu, return_posteriors=True
99108
)
109+
# make sure we can read node metadata - old tsdate versions didn't set a schema
110+
if dts.table_metadata_schemas.node.schema is None:
111+
tables = dts.dump_tables()
112+
tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json()
113+
dts = tables.tree_sequence()
114+
100115
# Only test nonsample node times
101-
nonsample_nodes = np.ones(ts.num_nodes, dtype=bool)
102-
nonsample_nodes[ts.samples()] = False
116+
nonsamples = np.ones(ts.num_nodes, dtype=bool)
117+
nonsamples[ts.samples()] = False
103118

104-
# Test the tree sequence times
105-
r_sq = (
106-
np.corrcoef(
107-
np.log(ts.nodes_time[nonsample_nodes]),
108-
np.log(dts.nodes_time[nonsample_nodes]),
109-
)[0, 1]
110-
** 2
111-
)
112-
assert r_sq >= min_r2_ts
119+
min_vals = {
120+
"r_sq": {"ts": min_r2_ts, "unconstr": min_r2_unconstrained},
121+
"spearmans_r": {"ts": min_spear_ts, "unconstr": min_spear_unconstrained},
122+
}
113123

114-
# Test the posterior means too.
115-
post_mean = np.array(
116-
[
117-
np.sum(posteriors[i] * posteriors["start_time"]) / np.sum(posteriors[i])
118-
for i in np.where(nonsample_nodes)[0]
119-
]
120-
)
121-
r_sq = (
122-
np.corrcoef(np.log(ts.nodes_time[nonsample_nodes]), np.log(post_mean))[0, 1]
123-
** 2
124-
)
125-
assert r_sq >= min_r2_posterior
124+
expected = ts.nodes_time[nonsamples]
125+
for (observed, src) in [
126+
(dts.nodes_time[nonsamples], "ts"),
127+
([dts.node(i).metadata["mn"] for i in np.where(nonsamples)[0]], "unconstr"),
128+
]:
129+
# Test both the tree sequence node times and the unconstrained metadata means
130+
r_sq = np.corrcoef(expected, observed)[0, 1] ** 2
131+
assert r_sq >= min_vals["r_sq"][src]
132+
133+
spearmans_r = scipy.stats.spearmanr(expected, observed).correlation
134+
assert spearmans_r >= min_vals["spearmans_r"][src]

0 commit comments

Comments
 (0)