@@ -155,7 +155,74 @@ def test_default_values_preprocess(self):
155
155
assert args .trim_telomeres
156
156
157
157
158
- class TestEndToEnd :
158
+ class RunCLI :
159
+ def run_tsdate_cli (self , input_ts , cmd = "" ):
160
+ with tempfile .TemporaryDirectory () as tmpdir :
161
+ input_filename = pathlib .Path (tmpdir ) / "input.trees"
162
+ input_ts .dump (input_filename )
163
+ output_filename = pathlib .Path (tmpdir ) / "output.trees"
164
+ full_cmd = "date " + str (input_filename ) + f" { output_filename } " + cmd
165
+ cli .tsdate_main (full_cmd .split ())
166
+ return tskit .load (output_filename )
167
+
168
+
169
+ class TestOutput (RunCLI ):
170
+ """
171
+ Tests for the command-line output.
172
+ """
173
+
174
+ popsize = 1
175
+
176
+ def test_bad_method (self , capfd ):
177
+ bad = "bad_method"
178
+ input_ts = msprime .simulate (4 , random_seed = 123 )
179
+ cmd = f"--method { bad } "
180
+ with pytest .raises (SystemExit ):
181
+ _ = self .run_tsdate_cli (input_ts , f"{ self .popsize } " + cmd )
182
+ captured = capfd .readouterr ()
183
+ assert bad in captured .err
184
+
185
+ def test_no_output (self , capfd ):
186
+ input_ts = msprime .simulate (4 , random_seed = 123 )
187
+ _ = self .run_tsdate_cli (input_ts , f"{ self .popsize } " )
188
+ (out , err ) = capfd .readouterr ()
189
+ assert out == ""
190
+ assert err == ""
191
+
192
+ def test_progress (self , capfd ):
193
+ input_ts = msprime .simulate (4 , random_seed = 123 )
194
+ cmd = "--method inside_outside --progress"
195
+ _ = self .run_tsdate_cli (input_ts , f"{ self .popsize } " + cmd )
196
+ (out , err ) = capfd .readouterr ()
197
+ assert out == ""
198
+ # run_tsdate_cli print logging to stderr
199
+ desc = (
200
+ "Find Node Spans" ,
201
+ "TipCount" ,
202
+ "Calculating Node Age Variances" ,
203
+ "Find Mixture Priors" ,
204
+ "Inside" ,
205
+ "Outside" ,
206
+ "Constrain Ages" ,
207
+ )
208
+ for match in desc :
209
+ assert match in err
210
+ assert err .count ("100%" ) == len (desc )
211
+ assert err .count ("it/s" ) >= len (desc )
212
+
213
+ def test_iterative_progress (self , capfd ):
214
+ input_ts = msprime .simulate (4 , random_seed = 123 )
215
+ cmd = "--method variational_gamma --mutation-rate 1e-8 --progress"
216
+ _ = self .run_tsdate_cli (input_ts , f"{ self .popsize } " + cmd )
217
+ (out , err ) = capfd .readouterr ()
218
+ assert out == ""
219
+ # run_tsdate_cli print logging to stderr
220
+ assert err .count ("Expectation Propagation: 100%" ) == 2
221
+ assert err .count ("EP (iter 2, rootwards): 100%" ) == 1
222
+ assert err .count ("rootwards): 100%" ) == err .count ("leafwards): 100%" )
223
+
224
+
225
+ class TestEndToEnd (RunCLI ):
159
226
"""
160
227
Class to test input to CLI outputs dated tree sequences.
161
228
"""
@@ -196,29 +263,16 @@ def ts_equal(self, ts1, ts2, times_equal=False):
196
263
assert t1 .nodes == t2 .nodes
197
264
198
265
def verify (self , input_ts , cmd ):
199
- with tempfile .TemporaryDirectory () as tmpdir :
200
- input_filename = pathlib .Path (tmpdir ) / "input.trees"
201
- input_ts .dump (input_filename )
202
- output_filename = pathlib .Path (tmpdir ) / "output.trees"
203
- full_cmd = "date " + str (input_filename ) + f" { output_filename } " + cmd
204
- cli .tsdate_main (full_cmd .split ())
205
- output_ts = tskit .load (output_filename )
266
+ output_ts = self .run_tsdate_cli (input_ts , cmd )
206
267
assert input_ts .num_samples == output_ts .num_samples
207
268
self .ts_equal (input_ts , output_ts )
208
269
209
270
def compare_python_api (self , input_ts , cmd , Ne , mutation_rate , method ):
210
- with tempfile .TemporaryDirectory () as tmpdir :
211
- input_filename = pathlib .Path (tmpdir ) / "input.trees"
212
- input_ts .dump (input_filename )
213
- output_filename = pathlib .Path (tmpdir ) / "output.trees"
214
- full_cmd = "date " + str (input_filename ) + f" { output_filename } " + cmd
215
- cli .tsdate_main (full_cmd .split ())
216
- output_ts = tskit .load (output_filename )
271
+ output_ts = self .run_tsdate_cli (input_ts , cmd )
217
272
dated_ts = tsdate .date (
218
273
input_ts , population_size = Ne , mutation_rate = mutation_rate , method = method
219
274
)
220
- # print(dated_ts.tables.nodes.time, output_ts.tables.nodes.time)
221
- assert np .array_equal (dated_ts .tables .nodes .time , output_ts .tables .nodes .time )
275
+ assert np .array_equal (dated_ts .nodes_time , output_ts .nodes_time )
222
276
223
277
def test_ts (self ):
224
278
input_ts = msprime .simulate (10 , random_seed = 1 )
0 commit comments