Skip to content

Commit 26b5081

Browse files
authored
Merge pull request #261 from hyanwong/cli-variational
Allow cli to work with variational_gamma
2 parents 115aa43 + eb44891 commit 26b5081

File tree

4 files changed

+71
-28
lines changed

4 files changed

+71
-28
lines changed

tests/test_cli.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_default_values(self):
5555
assert args.recombination_rate is None
5656
assert args.epsilon == 1e-6
5757
assert args.num_threads is None
58-
assert args.probability_space == "logarithmic"
58+
assert args.probability_space is None # Use the defaults
5959
assert args.method == "inside_outside"
6060
assert not args.progress
6161

@@ -128,16 +128,15 @@ def test_probability_space(self):
128128
)
129129
assert args.probability_space == "logarithmic"
130130

131-
def test_method(self):
131+
@pytest.mark.parametrize(
132+
"method", ["inside_outside", "maximization", "variational_gamma"]
133+
)
134+
def test_method(self, method):
132135
parser = cli.tsdate_cli_parser()
133136
args = parser.parse_args(
134-
["date", self.infile, self.output, "10000", "--method", "inside_outside"]
135-
)
136-
assert args.method == "inside_outside"
137-
args = parser.parse_args(
138-
["date", self.infile, self.output, "10000", "--method", "maximization"]
137+
["date", self.infile, self.output, "10000", "--method", method]
139138
)
140-
assert args.method == "maximization"
139+
assert args.method == method
141140

142141
def test_progress(self):
143142
parser = cli.tsdate_cli_parser()
@@ -262,7 +261,10 @@ def test_method(self):
262261
with pytest.raises(ValueError):
263262
self.verify(input_ts, cmd)
264263

265-
def test_compare_python_api(self):
264+
@pytest.mark.parametrize(
265+
"method", ["inside_outside", "maximization", "variational_gamma"]
266+
)
267+
def test_compare_python_api(self, method):
266268
input_ts = msprime.simulate(
267269
100,
268270
Ne=10000,
@@ -271,12 +273,9 @@ def test_compare_python_api(self):
271273
length=2e4,
272274
random_seed=10,
273275
)
274-
cmd = "10000 -m 1e-8 --method inside_outside"
275-
self.verify(input_ts, cmd)
276-
self.compare_python_api(input_ts, cmd, 10000, 1e-8, "inside_outside")
277-
cmd = "10000 -m 1e-8 --method maximization"
276+
cmd = f"10000 -m 1e-8 --method {method}"
278277
self.verify(input_ts, cmd)
279-
self.compare_python_api(input_ts, cmd, 10000, 1e-8, "maximization")
278+
self.compare_python_api(input_ts, cmd, 10000, 1e-8, method)
280279

281280
def preprocess_compare_python_api(self, input_ts):
282281
with tempfile.TemporaryDirectory() as tmpdir:

tests/test_functions.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from tsdate.core import Likelihoods
4747
from tsdate.core import LogLikelihoods
4848
from tsdate.core import posterior_mean_var
49+
from tsdate.core import variational_dates
4950
from tsdate.demography import PopulationSizeHistory
5051
from tsdate.prior import ConditionalCoalescentTimes
5152
from tsdate.prior import fill_priors
@@ -1531,6 +1532,25 @@ def test_bad_Ne(self):
15311532
tsdate.build_prior_grid(ts, population_size=-10)
15321533

15331534

1535+
class TestCallingErrors:
1536+
def test_bad_vgamma_probability_space(self):
1537+
ts = utility_functions.single_tree_ts_n2()
1538+
with pytest.raises(ValueError, match="Cannot specify"):
1539+
variational_dates(ts, 1, 1, probability_space=base.LOG)
1540+
1541+
def test_bad_vgamma_num_threads(self):
1542+
# Test can be removed if we specify num_threads in the future
1543+
ts = utility_functions.single_tree_ts_n2()
1544+
with pytest.raises(ValueError, match="does not currently"):
1545+
variational_dates(ts, 1, 1, num_threads=2)
1546+
1547+
def test_bad_vgamma_ignore_oldest_root(self):
1548+
# Test can be removed in the future if this is implemented
1549+
ts = utility_functions.single_tree_ts_n2()
1550+
with pytest.raises(ValueError, match="not implemented"):
1551+
variational_dates(ts, 1, 1, ignore_oldest_root=True)
1552+
1553+
15341554
class TestPosteriorMeanVar:
15351555
"""
15361556
Test posterior_mean_var works as expected

tsdate/cli.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -125,26 +125,29 @@ def tsdate_cli_parser():
125125
parser.add_argument(
126126
"--probability-space",
127127
type=str,
128-
default="logarithmic",
128+
default=None,
129129
help="Should the internal algorithm save probabilities in \
130130
'logarithmic' (slower, less liable to to overflow) or 'linear' \
131-
space (faster, may overflow). Default: 'logarithmic'",
131+
space (faster, may overflow). Not relevant for the \
132+
'variational_gamma' method; default otherwise is `None` \
133+
currently treated as 'logarithmic'",
132134
)
133135
parser.add_argument(
134136
"--method",
135137
type=str,
136138
default="inside_outside",
137139
help="Specify which estimation method to use: can be \
138140
'inside_outside' (empirically better, theoretically \
139-
problematic) or 'maximization' (worse empirically, especially \
140-
with a gamma approximated prior, but theoretically robust). \
141-
Default: 'inside_outside'",
141+
problematic), 'maximization' (worse empirically, especially \
142+
with a gamma approximated prior, but theoretically robust), or \
143+
'variational_gamma' (a fast experimental continuous-time \
144+
approximation). Default: 'inside_outside'",
142145
)
143146
parser.add_argument(
144147
"--ignore-oldest",
145148
action="store_true",
146-
help="Ignore the oldest node in the tree sequence, which is \
147-
often of low quality when using empirical data.",
149+
help="Ignore the oldest node in the tree sequence: in older tsinfer versions \
150+
this could be of low quality when using empirical data.",
148151
)
149152
parser.add_argument(
150153
"-p", "--progress", action="store_true", help="Show progress bar."
@@ -190,18 +193,18 @@ def run_date(args):
190193
ts = tskit.load(args.tree_sequence)
191194
except tskit.FileFormatError as ffe:
192195
error_exit(f"Error loading '{args.tree_sequence}: {ffe}")
193-
dated_ts = tsdate.date(
194-
ts,
195-
args.mutation_rate,
196-
args.population_size,
196+
params = dict(
197197
recombination_rate=args.recombination_rate,
198-
probability_space=args.probability_space,
199198
method=args.method,
200199
eps=args.epsilon,
200+
progress=args.progress,
201+
probability_space=args.probability_space,
201202
num_threads=args.num_threads,
202203
ignore_oldest_root=args.ignore_oldest,
203-
progress=args.progress,
204204
)
205+
# TODO: error out if ignore_oldest_root is set,
206+
# see https://github.com/tskit-dev/tsdate/issues/262
207+
dated_ts = tsdate.date(ts, args.mutation_rate, args.population_size, **params)
205208
dated_ts.dump(args.output)
206209

207210

tsdate/core.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,7 @@ def get_dates(
12871287
ignore_oldest_root=False,
12881288
progress=False,
12891289
cache_inside=False,
1290-
probability_space=base.LOG,
1290+
probability_space=None,
12911291
):
12921292
"""
12931293
Infer dates for the nodes in a tree sequence, returning an array of inferred dates
@@ -1333,6 +1333,9 @@ def get_dates(
13331333
)
13341334
priors = priors
13351335

1336+
if probability_space is None:
1337+
probability_space = base.LOG
1338+
13361339
if probability_space != base.LOG:
13371340
liklhd = Likelihoods(
13381341
tree_sequence,
@@ -1447,6 +1450,9 @@ def variational_dates(
14471450
global_prior=True,
14481451
eps=1e-6,
14491452
progress=False,
1453+
num_threads=None, # Unused, matches get_dates()
1454+
probability_space=None, # Can only be None, simply to match get_dates()
1455+
ignore_oldest_root=False, # Can only be False, simply to match get_dates()
14501456
):
14511457
"""
14521458
Infer dates for the nodes in a tree sequence using expectation propagation,
@@ -1469,6 +1475,21 @@ def variational_dates(
14691475
if not max_iterations >= 1:
14701476
raise ValueError("Maximum number of iterations must be greater than 0")
14711477

1478+
# Parameters below are not used in variational dating, but are here
1479+
# to match the signature of get_dates(). We may be able to remove some
1480+
# if we move to specifying some params via a control dictionary
1481+
1482+
if probability_space is not None:
1483+
raise ValueError("Cannot specify a probability space in variational dating")
1484+
1485+
if num_threads is not None and num_threads != 1:
1486+
raise ValueError("Variational dating does not currently use multiple threads")
1487+
1488+
if ignore_oldest_root:
1489+
raise ValueError(
1490+
"Ignoring the oldes root is not implemented in variational dating"
1491+
)
1492+
14721493
# Default to not creating approximate priors unless ts has > 1000 samples
14731494
approx_priors = False
14741495
if tree_sequence.num_samples > 1000:

0 commit comments

Comments
 (0)