@@ -55,7 +55,7 @@ def test_default_values(self):
55
55
assert args .recombination_rate is None
56
56
assert args .epsilon == 1e-6
57
57
assert args .num_threads is None
58
- assert args .probability_space == "logarithmic"
58
+ assert args .probability_space is None # Use the defaults
59
59
assert args .method == "inside_outside"
60
60
assert not args .progress
61
61
@@ -128,16 +128,15 @@ def test_probability_space(self):
128
128
)
129
129
assert args .probability_space == "logarithmic"
130
130
131
- def test_method (self ):
131
+ @pytest .mark .parametrize (
132
+ "method" , ["inside_outside" , "maximization" , "variational_gamma" ]
133
+ )
134
+ def test_method (self , method ):
132
135
parser = cli .tsdate_cli_parser ()
133
136
args = parser .parse_args (
134
- ["date" , self .infile , self .output , "10000" , "--method" , "inside_outside" ]
135
- )
136
- assert args .method == "inside_outside"
137
- args = parser .parse_args (
138
- ["date" , self .infile , self .output , "10000" , "--method" , "maximization" ]
137
+ ["date" , self .infile , self .output , "10000" , "--method" , method ]
139
138
)
140
- assert args .method == "maximization"
139
+ assert args .method == method
141
140
142
141
def test_progress (self ):
143
142
parser = cli .tsdate_cli_parser ()
@@ -262,7 +261,10 @@ def test_method(self):
262
261
with pytest .raises (ValueError ):
263
262
self .verify (input_ts , cmd )
264
263
265
- def test_compare_python_api (self ):
264
+ @pytest .mark .parametrize (
265
+ "method" , ["inside_outside" , "maximization" , "variational_gamma" ]
266
+ )
267
+ def test_compare_python_api (self , method ):
266
268
input_ts = msprime .simulate (
267
269
100 ,
268
270
Ne = 10000 ,
@@ -271,12 +273,9 @@ def test_compare_python_api(self):
271
273
length = 2e4 ,
272
274
random_seed = 10 ,
273
275
)
274
- cmd = "10000 -m 1e-8 --method inside_outside"
275
- self .verify (input_ts , cmd )
276
- self .compare_python_api (input_ts , cmd , 10000 , 1e-8 , "inside_outside" )
277
- cmd = "10000 -m 1e-8 --method maximization"
276
+ cmd = f"10000 -m 1e-8 --method { method } "
278
277
self .verify (input_ts , cmd )
279
- self .compare_python_api (input_ts , cmd , 10000 , 1e-8 , "maximization" )
278
+ self .compare_python_api (input_ts , cmd , 10000 , 1e-8 , method )
280
279
281
280
def preprocess_compare_python_api (self , input_ts ):
282
281
with tempfile .TemporaryDirectory () as tmpdir :
0 commit comments