Skip to content

Commit 667e6f2

Browse files
authored
Merge pull request #267 from hyanwong/iterative-progress
Output number of iterations using EP
2 parents 26b5081 + 2fa8dcc commit 667e6f2

File tree

3 files changed

+99
-37
lines changed

3 files changed

+99
-37
lines changed

tests/test_cli.py

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,74 @@ def test_default_values_preprocess(self):
155155
assert args.trim_telomeres
156156

157157

158-
class TestEndToEnd:
158+
class RunCLI:
159+
def run_tsdate_cli(self, input_ts, cmd=""):
160+
with tempfile.TemporaryDirectory() as tmpdir:
161+
input_filename = pathlib.Path(tmpdir) / "input.trees"
162+
input_ts.dump(input_filename)
163+
output_filename = pathlib.Path(tmpdir) / "output.trees"
164+
full_cmd = "date " + str(input_filename) + f" {output_filename} " + cmd
165+
cli.tsdate_main(full_cmd.split())
166+
return tskit.load(output_filename)
167+
168+
169+
class TestOutput(RunCLI):
170+
"""
171+
Tests for the command-line output.
172+
"""
173+
174+
popsize = 1
175+
176+
def test_bad_method(self, capfd):
177+
bad = "bad_method"
178+
input_ts = msprime.simulate(4, random_seed=123)
179+
cmd = f"--method {bad}"
180+
with pytest.raises(SystemExit):
181+
_ = self.run_tsdate_cli(input_ts, f"{self.popsize} " + cmd)
182+
captured = capfd.readouterr()
183+
assert bad in captured.err
184+
185+
def test_no_output(self, capfd):
186+
input_ts = msprime.simulate(4, random_seed=123)
187+
_ = self.run_tsdate_cli(input_ts, f"{self.popsize}")
188+
(out, err) = capfd.readouterr()
189+
assert out == ""
190+
assert err == ""
191+
192+
def test_progress(self, capfd):
193+
input_ts = msprime.simulate(4, random_seed=123)
194+
cmd = "--method inside_outside --progress"
195+
_ = self.run_tsdate_cli(input_ts, f"{self.popsize} " + cmd)
196+
(out, err) = capfd.readouterr()
197+
assert out == ""
198+
# run_tsdate_cli prints logging to stderr
199+
desc = (
200+
"Find Node Spans",
201+
"TipCount",
202+
"Calculating Node Age Variances",
203+
"Find Mixture Priors",
204+
"Inside",
205+
"Outside",
206+
"Constrain Ages",
207+
)
208+
for match in desc:
209+
assert match in err
210+
assert err.count("100%") == len(desc)
211+
assert err.count("it/s") >= len(desc)
212+
213+
def test_iterative_progress(self, capfd):
214+
input_ts = msprime.simulate(4, random_seed=123)
215+
cmd = "--method variational_gamma --mutation-rate 1e-8 --progress"
216+
_ = self.run_tsdate_cli(input_ts, f"{self.popsize} " + cmd)
217+
(out, err) = capfd.readouterr()
218+
assert out == ""
219+
# run_tsdate_cli prints logging to stderr
220+
assert err.count("Expectation Propagation: 100%") == 2
221+
assert err.count("EP (iter 2, rootwards): 100%") == 1
222+
assert err.count("rootwards): 100%") == err.count("leafwards): 100%")
223+
224+
225+
class TestEndToEnd(RunCLI):
159226
"""
160227
Class to test input to CLI outputs dated tree sequences.
161228
"""
@@ -196,29 +263,16 @@ def ts_equal(self, ts1, ts2, times_equal=False):
196263
assert t1.nodes == t2.nodes
197264

198265
def verify(self, input_ts, cmd):
199-
with tempfile.TemporaryDirectory() as tmpdir:
200-
input_filename = pathlib.Path(tmpdir) / "input.trees"
201-
input_ts.dump(input_filename)
202-
output_filename = pathlib.Path(tmpdir) / "output.trees"
203-
full_cmd = "date " + str(input_filename) + f" {output_filename} " + cmd
204-
cli.tsdate_main(full_cmd.split())
205-
output_ts = tskit.load(output_filename)
266+
output_ts = self.run_tsdate_cli(input_ts, cmd)
206267
assert input_ts.num_samples == output_ts.num_samples
207268
self.ts_equal(input_ts, output_ts)
208269

209270
def compare_python_api(self, input_ts, cmd, Ne, mutation_rate, method):
210-
with tempfile.TemporaryDirectory() as tmpdir:
211-
input_filename = pathlib.Path(tmpdir) / "input.trees"
212-
input_ts.dump(input_filename)
213-
output_filename = pathlib.Path(tmpdir) / "output.trees"
214-
full_cmd = "date " + str(input_filename) + f" {output_filename} " + cmd
215-
cli.tsdate_main(full_cmd.split())
216-
output_ts = tskit.load(output_filename)
271+
output_ts = self.run_tsdate_cli(input_ts, cmd)
217272
dated_ts = tsdate.date(
218273
input_ts, population_size=Ne, mutation_rate=mutation_rate, method=method
219274
)
220-
# print(dated_ts.tables.nodes.time, output_ts.tables.nodes.time)
221-
assert np.array_equal(dated_ts.tables.nodes.time, output_ts.tables.nodes.time)
275+
assert np.array_equal(dated_ts.nodes_time, output_ts.nodes_time)
222276

223277
def test_ts(self):
224278
input_ts = msprime.simulate(10, random_seed=1)

tsdate/cli.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,15 @@ def tsdate_cli_parser():
134134
)
135135
parser.add_argument(
136136
"--method",
137-
type=str,
137+
choices=["inside_outside", "maximization", "variational_gamma"],
138138
default="inside_outside",
139-
help="Specify which estimation method to use: can be \
140-
'inside_outside' (empirically better, theoretically \
141-
problematic), 'maximization' (worse empirically, especially \
142-
with a gamma approximated prior, but theoretically robust), or \
143-
'variational_gamma' (a fast experimental continuous-time \
144-
approximation). Default: 'inside_outside'",
139+
help=(
140+
"Specify which estimation method to use: "
141+
"'inside_outside' is empirically better, but theoretically problematic, "
142+
"'maximization' is worse empirically, especially with a gamma prior, but "
143+
"theoretically robust, 'variational_gamma' is a fast experimental "
144+
"continuous-time approximation. Current default: 'inside_outside'",
145+
),
145146
)
146147
parser.add_argument(
147148
"--ignore-oldest",

tsdate/core.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -992,19 +992,14 @@ def __init__(self, *args, **kwargs):
992992
)
993993
# self.factor_norm[edge.id] += ... # TODO
994994

995-
def propagate(self, *, edges, progress=None):
995+
def propagate(self, *, edges, desc=None, progress=None):
996996
"""
997997
Update approximating factor for each edge
998998
"""
999999
if progress is None:
10001000
progress = self.progress
10011001
# TODO: this will still converge if parallelized (potentially slower)
1002-
for edge in tqdm(
1003-
edges,
1004-
desc="Expectation Propagation",
1005-
total=self.ts.num_edges,
1006-
disable=not progress,
1007-
):
1002+
for edge in tqdm(edges, desc, total=self.ts.num_edges, disable=not progress):
10081003
if edge.child in self.fixednodes:
10091004
continue
10101005
if edge.parent in self.fixednodes:
@@ -1042,13 +1037,22 @@ def propagate(self, *, edges, progress=None):
10421037
# TODO not complete
10431038
self.factor_norm[edge.id] = norm_const
10441039

1045-
def iterate(self, *, progress=None, **kwargs):
1040+
def iterate(self, *, iter_num=None, progress=None):
10461041
"""
10471042
Update edge factors from leaves to root then from root to leaves,
10481043
and return approximate log marginal likelihood
10491044
"""
1050-
self.propagate(edges=self.edges_by_parent_asc(grouped=False), progress=progress)
1051-
self.propagate(edges=self.edges_by_child_desc(grouped=False), progress=progress)
1045+
desc = "Expectation Propagation"
1046+
if iter_num: # Show iteration number if not first iteration
1047+
desc = f"EP (iter {iter_num + 1:>2}, rootwards)"
1048+
self.propagate(
1049+
edges=self.edges_by_parent_asc(grouped=False), desc=desc, progress=progress
1050+
)
1051+
if iter_num:
1052+
desc = f"EP (iter {iter_num + 1:>2}, leafwards)"
1053+
self.propagate(
1054+
edges=self.edges_by_child_desc(grouped=False), desc=desc, progress=progress
1055+
)
10521056
# TODO
10531057
# marginal_lik = np.sum(self.factor_norm)
10541058
# return marginal_lik
@@ -1112,7 +1116,10 @@ def constrain_ages_topo(ts, post_mn, eps, nodes_to_date=None, progress=False):
11121116
parents_unique = np.unique(parents, return_index=True)
11131117
parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
11141118
for index, nd in tqdm(
1115-
enumerate(sorted(nodes_to_date)), desc="Constrain Ages", disable=not progress
1119+
enumerate(sorted(nodes_to_date)),
1120+
desc="Constrain Ages",
1121+
total=len(nodes_to_date),
1122+
disable=not progress,
11161123
):
11171124
if index + 1 != len(nodes_to_date):
11181125
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
@@ -1530,8 +1537,8 @@ def variational_dates(
15301537
)
15311538

15321539
dynamic_prog = ExpectationPropagation(priors, liklhd, progress=progress)
1533-
for _ in range(max_iterations):
1534-
dynamic_prog.iterate()
1540+
for it in range(max_iterations):
1541+
dynamic_prog.iterate(iter_num=it)
15351542
posterior = dynamic_prog.posterior
15361543
tree_sequence, mn_post, _ = variational_mean_var(
15371544
tree_sequence, posterior, fixed_node_set=fixed_nodes

0 commit comments

Comments
 (0)