Skip to content

Commit 6fec400

Browse files
committed
Update tests
1 parent f2a344b commit 6fec400

File tree

10 files changed

+126
-2111
lines changed

10 files changed

+126
-2111
lines changed

1

Lines changed: 0 additions & 2021 deletions
This file was deleted.

tests/test_accuracy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_basic(
104104
mu = sim_mutations_parameters["rate"]
105105

106106
dts, posteriors = tsdate.date(
107-
ts, Ne=Ne, mutation_rate=mu, return_posteriors=True
107+
ts, population_size=Ne, mutation_rate=mu, return_posteriors=True
108108
)
109109
# make sure we can read node metadata - old tsdate versions didn't set a schema
110110
if dts.table_metadata_schemas.node.schema is None:
@@ -139,6 +139,6 @@ def test_scaling(self, Ne):
139139
Test that we are in the right theoretical ballpark given known Ne
140140
"""
141141
ts = tskit.Tree.generate_comb(2).tree_sequence
142-
dts = tsdate.date(ts, Ne=Ne, mutation_rate=None)
142+
dts = tsdate.date(ts, population_size=Ne, mutation_rate=None)
143143
# Check the date is within 10% of the expected
144144
assert 0.9 < dts.node(dts.first().root).time / (2 * Ne) < 1.1

tests/test_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ def test_cached_prior(self):
2727
if os.path.isfile(fn):
2828
raise AssertionError(f"The file {fn} already exists. Delete before testing")
2929
with self.assertLogs(level="WARNING") as log:
30-
priors_approx10 = ConditionalCoalescentTimes(10, Ne=1)
30+
priors_approx10 = ConditionalCoalescentTimes(10)
3131
assert len(log.output) == 1
3232
assert "user cache" in log.output[0]
3333
priors_approx10.add(10)
3434
# Check we have created the prior file
3535
assert os.path.isfile(fn)
36-
priors_approxNone = ConditionalCoalescentTimes(None, Ne=1)
36+
priors_approxNone = ConditionalCoalescentTimes(None)
3737
priors_approxNone.add(10)
3838
assert np.allclose(priors_approx10[10], priors_approxNone[10], equal_nan=True)
3939
# Test when using a bigger n that we're using the precalculated version

tests/test_cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_default_values(self):
5050
args = parser.parse_args(["date", self.infile, self.output, "1"])
5151
assert args.tree_sequence == self.infile
5252
assert args.output == self.output
53-
assert args.Ne == 1
53+
assert args.population_size == 1
5454
assert args.mutation_rate is None
5555
assert args.recombination_rate is None
5656
assert args.epsilon == 1e-6
@@ -216,7 +216,7 @@ def compare_python_api(self, input_ts, cmd, Ne, mutation_rate, method):
216216
cli.tsdate_main(full_cmd.split())
217217
output_ts = tskit.load(output_filename)
218218
dated_ts = tsdate.date(
219-
input_ts, Ne=Ne, mutation_rate=mutation_rate, method=method
219+
input_ts, population_size=Ne, mutation_rate=mutation_rate, method=method
220220
)
221221
print(dated_ts.tables.nodes.time, output_ts.tables.nodes.time)
222222
assert np.array_equal(dated_ts.tables.nodes.time, output_ts.tables.nodes.time)

tests/test_functions.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ class TestMakePrior:
265265
# We only test make_prior() on single trees
266266
def verify_priors(self, ts, prior_distr):
267267
# Check prior contains all possible tips
268-
priors = ConditionalCoalescentTimes(None, Ne=0.5, prior_distr=prior_distr)
268+
priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr)
269269
priors.add(ts.num_samples)
270270
priors_df = priors[ts.num_samples]
271271
assert priors_df.shape[0] == ts.num_samples + 1
@@ -406,7 +406,7 @@ class TestMixturePrior:
406406

407407
def get_mixture_prior_params(self, ts, prior_distr, **kwargs):
408408
span_data = SpansBySamples(ts, **kwargs)
409-
priors = ConditionalCoalescentTimes(None, Ne=0.5, prior_distr=prior_distr)
409+
priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr)
410410
priors.add(ts.num_samples, approximate=False)
411411
mixture_priors = priors.get_mixture_prior_params(span_data)
412412
return mixture_priors
@@ -512,8 +512,8 @@ def test_two_tree_mutation_ts_intervals(self):
512512
class TestPriorVals:
513513
def verify_prior_vals(self, ts, prior_distr, **kwargs):
514514
span_data = SpansBySamples(ts, **kwargs)
515-
Ne = 0.5
516-
priors = ConditionalCoalescentTimes(None, Ne=Ne, prior_distr=prior_distr)
515+
Ne = np.array([[0, 0.5]])
516+
priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr)
517517
priors.add(ts.num_samples, approximate=False)
518518
grid = np.linspace(0, 3, 3)
519519
mixture_priors = priors.get_mixture_prior_params(span_data)
@@ -1102,8 +1102,8 @@ def run_outside_algorithm(
11021102
self, ts, prior_distr="lognorm", standardize=False, ignore_oldest_root=False
11031103
):
11041104
span_data = SpansBySamples(ts)
1105-
Ne = 0.5
1106-
priors = ConditionalCoalescentTimes(None, Ne, prior_distr)
1105+
Ne = np.array([[0, 0.5]])
1106+
priors = ConditionalCoalescentTimes(None, prior_distr)
11071107
priors.add(ts.num_samples, approximate=False)
11081108
grid = np.array([0, 1.2, 2])
11091109
mixture_priors = priors.get_mixture_prior_params(span_data)
@@ -1205,8 +1205,8 @@ class TestTotalFunctionalValueTree:
12051205
def find_posterior(self, ts, prior_distr):
12061206
grid = np.array([0, 1.2, 2])
12071207
span_data = SpansBySamples(ts)
1208-
Ne = 0.5
1209-
priors = ConditionalCoalescentTimes(None, Ne=Ne, prior_distr=prior_distr)
1208+
Ne = np.array([[0, 0.5]])
1209+
priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr)
12101210
priors.add(ts.num_samples, approximate=False)
12111211
mixture_priors = priors.get_mixture_prior_params(span_data)
12121212
prior_vals = fill_priors(mixture_priors, grid, ts, Ne, prior_distr=prior_distr)
@@ -1269,13 +1269,13 @@ def test_gil_tree(self):
12691269
ts = utility_functions.gils_example_tree()
12701270
span_data = SpansBySamples(ts)
12711271
prior_distr = "lognorm"
1272-
Ne = 0.5
1273-
priors = ConditionalCoalescentTimes(None, Ne, prior_distr=prior_distr)
1272+
Ne = np.array([[0, 0.5]])
1273+
priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr)
12741274
priors.add(ts.num_samples, approximate=False)
12751275
grid = np.array([0, 0.1, 0.2, 0.5, 1, 2, 5])
12761276
mixture_prior = priors.get_mixture_prior_params(span_data)
12771277
prior_vals = fill_priors(
1278-
mixture_prior, grid, ts, 1, prior_distr=prior_distr
1278+
mixture_prior, grid, ts, Ne, prior_distr=prior_distr
12791279
)
12801280
prior_vals.grid_data[0] = [0, 0.5, 0.3, 0.1, 0.05, 0.02, 0.03]
12811281
prior_vals.grid_data[1] = [0, 0.05, 0.1, 0.2, 0.45, 0.1, 0.1]
@@ -1458,28 +1458,30 @@ class TestDate:
14581458
def test_date_input(self):
14591459
ts = utility_functions.single_tree_ts_n2()
14601460
with pytest.raises(ValueError):
1461-
tsdate.date(ts, mutation_rate=None, Ne=1, method="foobar")
1461+
tsdate.date(ts, mutation_rate=None, population_size=1, method="foobar")
14621462

14631463
def test_sample_as_parent_fails(self):
14641464
ts = utility_functions.single_tree_ts_n3_sample_as_parent()
14651465
with pytest.raises(NotImplementedError):
1466-
tsdate.date(ts, mutation_rate=None, Ne=1)
1466+
tsdate.date(ts, mutation_rate=None, population_size=1)
14671467

14681468
def test_recombination_not_implemented(self):
14691469
ts = utility_functions.single_tree_ts_n2()
14701470
with pytest.raises(NotImplementedError):
1471-
tsdate.date(ts, mutation_rate=None, Ne=1, recombination_rate=1e-8)
1471+
tsdate.date(
1472+
ts, mutation_rate=None, population_size=1, recombination_rate=1e-8
1473+
)
14721474

14731475
def test_Ne_and_priors(self):
14741476
ts = utility_functions.single_tree_ts_n2()
14751477
with pytest.raises(ValueError):
1476-
priors = tsdate.build_prior_grid(ts, Ne=1)
1477-
tsdate.date(ts, mutation_rate=None, Ne=1, priors=priors)
1478+
priors = tsdate.build_prior_grid(ts, population_size=1)
1479+
tsdate.date(ts, mutation_rate=None, population_size=1, priors=priors)
14781480

14791481
def test_no_Ne_priors(self):
14801482
ts = utility_functions.single_tree_ts_n2()
14811483
with pytest.raises(ValueError):
1482-
tsdate.date(ts, mutation_rate=None, Ne=None, priors=None)
1484+
tsdate.date(ts, mutation_rate=None, population_size=None, priors=None)
14831485

14841486

14851487
class TestBuildPriorGrid:
@@ -1512,7 +1514,7 @@ def test_bad_prior_distr(self):
15121514
def test_bad_Ne(self):
15131515
ts = msprime.simulate(2, random_seed=12)
15141516
with pytest.raises(ValueError):
1515-
tsdate.build_prior_grid(ts, Ne=-10)
1517+
tsdate.build_prior_grid(ts, population_size=-10)
15161518

15171519

15181520
class TestPosteriorMeanVar:
@@ -1545,8 +1547,10 @@ def test_node_metadata_simulated_tree(self):
15451547
larger_ts = msprime.simulate(
15461548
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
15471549
)
1548-
_, mn_post, _, _, eps, _ = get_dates(larger_ts, mutation_rate=None, Ne=10000)
1549-
dated_ts = date(larger_ts, Ne=10000, mutation_rate=None)
1550+
_, mn_post, _, _, eps, _ = get_dates(
1551+
larger_ts, mutation_rate=None, population_size=10000
1552+
)
1553+
dated_ts = date(larger_ts, population_size=10000, mutation_rate=None)
15501554
metadata = dated_ts.tables.nodes.metadata
15511555
metadata_offset = dated_ts.tables.nodes.metadata_offset
15521556
unconstrained_mn = [
@@ -1709,7 +1713,7 @@ def test_node_times(self):
17091713
larger_ts = msprime.simulate(
17101714
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
17111715
)
1712-
dated = date(larger_ts, mutation_rate=None, Ne=10000)
1716+
dated = date(larger_ts, mutation_rate=None, population_size=10000)
17131717
node_ages = nodes_time_unconstrained(dated)
17141718
assert np.all(dated.tables.nodes.time[:] >= node_ages)
17151719

@@ -1736,8 +1740,8 @@ def test_node_selection_param(self):
17361740

17371741
def test_sites_time_insideoutside(self):
17381742
ts = utility_functions.two_tree_mutation_ts()
1739-
dated = tsdate.date(ts, mutation_rate=None, Ne=1)
1740-
_, mn_post, _, _, eps, _ = get_dates(ts, mutation_rate=None, Ne=1)
1743+
dated = tsdate.date(ts, mutation_rate=None, population_size=1)
1744+
_, mn_post, _, _, eps, _ = get_dates(ts, mutation_rate=None, population_size=1)
17411745
assert np.array_equal(
17421746
mn_post[ts.tables.mutations.node],
17431747
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
@@ -1749,15 +1753,17 @@ def test_sites_time_insideoutside(self):
17491753

17501754
def test_sites_time_maximization(self):
17511755
ts = utility_functions.two_tree_mutation_ts()
1752-
dated = tsdate.date(ts, Ne=1, mutation_rate=1, method="maximization")
1756+
dated = tsdate.date(
1757+
ts, population_size=1, mutation_rate=1, method="maximization"
1758+
)
17531759
assert np.array_equal(
17541760
dated.tables.nodes.time[ts.tables.mutations.node],
17551761
tsdate.sites_time_from_ts(dated, unconstrained=False, min_time=0),
17561762
)
17571763

17581764
def test_sites_time_node_selection(self):
17591765
ts = utility_functions.two_tree_mutation_ts()
1760-
dated = tsdate.date(ts, Ne=1, mutation_rate=1)
1766+
dated = tsdate.date(ts, population_size=1, mutation_rate=1)
17611767
sites_time_child = tsdate.sites_time_from_ts(
17621768
dated, node_selection="child", min_time=0
17631769
)
@@ -1838,8 +1844,10 @@ def test_sites_time_simulated(self):
18381844
larger_ts = msprime.simulate(
18391845
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
18401846
)
1841-
_, mn_post, _, _, _, _ = get_dates(larger_ts, mutation_rate=None, Ne=10000)
1842-
dated = date(larger_ts, mutation_rate=None, Ne=10000)
1847+
_, mn_post, _, _, _, _ = get_dates(
1848+
larger_ts, mutation_rate=None, population_size=10000
1849+
)
1850+
dated = date(larger_ts, mutation_rate=None, population_size=10000)
18431851
assert np.array_equal(
18441852
mn_post[larger_ts.tables.mutations.node],
18451853
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
@@ -1944,7 +1952,7 @@ def test_historical_samples(self):
19441952
ts.simplify(ts.samples(time=0), filter_sites=False)
19451953
)
19461954
inferred_ts = tsinfer.infer(modern_samples).simplify(filter_sites=False)
1947-
dated_ts = tsdate.date(inferred_ts, Ne=1, mutation_rate=1)
1955+
dated_ts = tsdate.date(inferred_ts, population_size=1, mutation_rate=1)
19481956
site_times = tsdate.sites_time_from_ts(dated_ts)
19491957
# make a sd file with historical individual times
19501958
samples = tsinfer.SampleData.from_tree_sequence(

0 commit comments

Comments
 (0)