29
29
import msprime
30
30
import numpy as np
31
31
import pytest
32
+ import scipy
32
33
import tskit
33
34
34
35
import tsdate
@@ -46,7 +47,7 @@ def test_make_static_files(self, request):
46
47
So that we are assured of using the same tree sequence, regardless of the
47
48
version and random number generator used in msprime, we keep these
48
49
as static files and only run this function when explicitly specified, e.g. via
49
- pytest test_accuracy.py::TestAccuracy::create_static_files
50
+ pytest test_accuracy.py::TestAccuracy::test_make_static_files
50
51
"""
51
52
mu = 1e-6
52
53
Ne = 1e4
@@ -74,14 +75,22 @@ def test_make_static_files(self, request):
74
75
ts .dump (os .path .join (request .fspath .dirname , "data" , f"{ name } .trees" ))
75
76
76
77
@pytest .mark .parametrize (
77
- "ts_name,min_r2_ts,min_r2_posterior " ,
78
+ "ts_name,min_r2_ts,min_r2_unconstrained,min_spear_ts,min_spear_unconstrained " ,
78
79
[
79
- ("one_tree" , 0.94776615238 , 0.94776615238 ),
80
- ("few_trees" , 0.96605244 , 0.96605244 ),
81
- ("many_trees" , 0.92646 , 0.92646 ),
80
+ ("one_tree" , 0.98601 , 0.98601 , 0.97719 , 0.97719 ),
81
+ ("few_trees" , 0.98220 , 0.98220 , 0.97744 , 0.97744 ),
82
+ ("many_trees" , 0.93449 , 0.93449 , 0.964547 , 0.964547 ),
82
83
],
83
84
)
84
- def test_basic (self , ts_name , min_r2_ts , min_r2_posterior , request ):
85
+ def test_basic (
86
+ self ,
87
+ ts_name ,
88
+ min_r2_ts ,
89
+ min_r2_unconstrained ,
90
+ min_spear_ts ,
91
+ min_spear_unconstrained ,
92
+ request ,
93
+ ):
85
94
ts = tskit .load (
86
95
os .path .join (request .fspath .dirname , "data" , ts_name + ".trees" )
87
96
)
@@ -97,29 +106,29 @@ def test_basic(self, ts_name, min_r2_ts, min_r2_posterior, request):
97
106
dts , posteriors = tsdate .date (
98
107
ts , Ne = Ne , mutation_rate = mu , return_posteriors = True
99
108
)
109
+ # make sure we can read node metadata - old tsdate versions didn't set a schema
110
+ if dts .table_metadata_schemas .node .schema is None :
111
+ tables = dts .dump_tables ()
112
+ tables .nodes .metadata_schema = tskit .MetadataSchema .permissive_json ()
113
+ dts = tables .tree_sequence ()
114
+
100
115
# Only test nonsample node times
101
- nonsample_nodes = np .ones (ts .num_nodes , dtype = bool )
102
- nonsample_nodes [ts .samples ()] = False
116
+ nonsamples = np .ones (ts .num_nodes , dtype = bool )
117
+ nonsamples [ts .samples ()] = False
103
118
104
- # Test the tree sequence times
105
- r_sq = (
106
- np .corrcoef (
107
- np .log (ts .nodes_time [nonsample_nodes ]),
108
- np .log (dts .nodes_time [nonsample_nodes ]),
109
- )[0 , 1 ]
110
- ** 2
111
- )
112
- assert r_sq >= min_r2_ts
119
+ min_vals = {
120
+ "r_sq" : {"ts" : min_r2_ts , "unconstr" : min_r2_unconstrained },
121
+ "spearmans_r" : {"ts" : min_spear_ts , "unconstr" : min_spear_unconstrained },
122
+ }
113
123
114
- # Test the posterior means too.
115
- post_mean = np .array (
116
- [
117
- np .sum (posteriors [i ] * posteriors ["start_time" ]) / np .sum (posteriors [i ])
118
- for i in np .where (nonsample_nodes )[0 ]
119
- ]
120
- )
121
- r_sq = (
122
- np .corrcoef (np .log (ts .nodes_time [nonsample_nodes ]), np .log (post_mean ))[0 , 1 ]
123
- ** 2
124
- )
125
- assert r_sq >= min_r2_posterior
124
+ expected = ts .nodes_time [nonsamples ]
125
+ for (observed , src ) in [
126
+ (dts .nodes_time [nonsamples ], "ts" ),
127
+ ([dts .node (i ).metadata ["mn" ] for i in np .where (nonsamples )[0 ]], "unconstr" ),
128
+ ]:
129
+ # Test the tree sequence times
130
+ r_sq = np .corrcoef (expected , observed )[0 , 1 ] ** 2
131
+ assert r_sq >= min_vals ["r_sq" ][src ]
132
+
133
+ spearmans_r = scipy .stats .spearmanr (expected , observed ).correlation
134
+ assert spearmans_r >= min_vals ["spearmans_r" ][src ]
0 commit comments