Skip to content

Commit 77aa230

Browse files
authored
Merge pull request #220 from hyanwong/accuracy-test
Accuracy test
2 parents 7737a55 + d7eb3ae commit 77aa230

File tree

6 files changed

+160
-1
lines changed

6 files changed

+160
-1
lines changed

tests/conftest.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import pytest
2+
3+
4+
def pytest_addoption(parser):
5+
parser.addoption(
6+
"--make-files",
7+
action="store_true",
8+
default=False,
9+
help="generate static tree sequences used for testing",
10+
)
11+
12+
13+
def pytest_collection_modifyitems(config, items):
14+
if config.getoption("--make-files"):
15+
# --make-files given in cli: only run tests marked @pytest.mark.makefiles
16+
skip_normal_tests = pytest.mark.skip(
17+
reason="--make-files specified, so other tests skipped"
18+
)
19+
for item in items:
20+
if "makefiles" not in item.keywords:
21+
item.add_marker(skip_normal_tests)
22+
else:
23+
skip_make_files = pytest.mark.skip(
24+
reason="specify --make-files to (re)create various files used for testing"
25+
)
26+
for item in items:
27+
if "makefiles" in item.keywords:
28+
item.add_marker(skip_make_files)
29+
30+
31+
def pytest_configure(config):
32+
config.addinivalue_line(
33+
"markers", "makefiles: mark test to run only when --make-files option given"
34+
)

tests/data/few_trees.trees

15.4 KB
Binary file not shown.

tests/data/many_trees.trees

22.1 KB
Binary file not shown.

tests/data/one_tree.trees

14.4 KB
Binary file not shown.

tests/test_accuracy.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# MIT License
2+
#
3+
# Copyright (c) 2022 University of Oxford
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
#
12+
# The above copyright notice and this permission notice shall be included in
13+
# all
14+
# copies or substantial portions of the Software.
15+
#
16+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
# SOFTWARE.
23+
"""
24+
Test cases for tsdate accuracy.
25+
"""
26+
import json
27+
import os
28+
29+
import msprime
30+
import numpy as np
31+
import pytest
32+
import tskit
33+
34+
import tsdate
35+
36+
37+
class TestAccuracy:
38+
"""
39+
Test for some of the basic functions used in tsdate
40+
"""
41+
42+
@pytest.mark.makefiles
43+
def test_make_static_files(self, request):
44+
"""
45+
The function used to create the tree sequences for accuracy testing.
46+
So that we are assured of using the same tree sequence, regardless of the
47+
version and random number generator used in msprime, we keep these
48+
as static files and only run this function when explicitly specified, e.g. via
49+
pytest test_accuracy.py::TestAccuracy::create_static_files
50+
"""
51+
mu = 1e-6
52+
Ne = 1e4
53+
seed = 123
54+
for name, rho in zip(
55+
["one_tree", "few_trees", "many_trees"],
56+
[0, 7e-9, 1.3e-7], # Chosen to give 1, 2, and 25 trees
57+
):
58+
ts = msprime.sim_ancestry(
59+
10,
60+
population_size=Ne,
61+
sequence_length=1e3,
62+
recombination_rate=rho,
63+
random_seed=seed,
64+
)
65+
if name != "one_tree":
66+
assert ts.num_trees > 1
67+
if name == "few_trees":
68+
assert ts.num_trees < 5
69+
if name == "many_trees":
70+
assert ts.num_trees >= 20
71+
72+
ts = msprime.sim_mutations(ts, rate=mu, random_seed=seed)
73+
assert ts.num_mutations > 100
74+
ts.dump(os.path.join(request.fspath.dirname, "data", f"{name}.trees"))
75+
76+
@pytest.mark.parametrize(
77+
"ts_name,min_r2_ts,min_r2_posterior",
78+
[
79+
("one_tree", 0.94776615238, 0.94776615238),
80+
("few_trees", 0.96605244, 0.96605244),
81+
("many_trees", 0.92646, 0.92646),
82+
],
83+
)
84+
def test_basic(self, ts_name, min_r2_ts, min_r2_posterior, request):
85+
ts = tskit.load(
86+
os.path.join(request.fspath.dirname, "data", ts_name + ".trees")
87+
)
88+
89+
sim_ancestry_parameters = json.loads(ts.provenance(0).record)["parameters"]
90+
assert sim_ancestry_parameters["command"] == "sim_ancestry"
91+
Ne = sim_ancestry_parameters["population_size"]
92+
93+
sim_mutations_parameters = json.loads(ts.provenance(1).record)["parameters"]
94+
assert sim_mutations_parameters["command"] == "sim_mutations"
95+
mu = sim_mutations_parameters["rate"]
96+
97+
dts, posteriors = tsdate.date(
98+
ts, Ne=Ne, mutation_rate=mu, return_posteriors=True
99+
)
100+
# Only test nonsample node times
101+
nonsample_nodes = np.ones(ts.num_nodes, dtype=bool)
102+
nonsample_nodes[ts.samples()] = False
103+
104+
# Test the tree sequence times
105+
r_sq = (
106+
np.corrcoef(
107+
np.log(ts.nodes_time[nonsample_nodes]),
108+
np.log(dts.nodes_time[nonsample_nodes]),
109+
)[0, 1]
110+
** 2
111+
)
112+
assert r_sq >= min_r2_ts
113+
114+
# Test the posterior means too.
115+
post_mean = np.array(
116+
[
117+
np.sum(posteriors[i] * posteriors["start_time"]) / np.sum(posteriors[i])
118+
for i in np.where(nonsample_nodes)[0]
119+
]
120+
)
121+
r_sq = (
122+
np.corrcoef(np.log(ts.nodes_time[nonsample_nodes]), np.log(post_mean))[0, 1]
123+
** 2
124+
)
125+
assert r_sq >= min_r2_posterior

tsdate/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def force_probability_space(self, probability_space):
119119
if self.probability_space == LOG:
120120
pass
121121
elif self.probability_space == LIN:
122-
with np.errstate(divide="ignore"):
122+
with np.errstate(divide="ignore", invalid="ignore"):
123123
self.grid_data = np.log(self.grid_data)
124124
self.fixed_data = np.log(self.fixed_data)
125125
self.probability_space = LOG

0 commit comments

Comments
 (0)