Skip to content

Commit 25f79f5

Browse files
hyanwongbenjeffery
authored andcommitted
Add --keep-intermediates CLI option
1 parent d66eef9 commit 25f79f5

File tree

2 files changed

+93
-21
lines changed

2 files changed

+93
-21
lines changed

tests/test_cli.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,42 @@ def test_augment_ancestors(self):
226226
self.verify_output(output_trees)
227227

228228

229+
class TestCommandsExtra(TestCli):
230+
"""
231+
Test miscellaneous extra options for standard commands
232+
"""
233+
234+
def test_filenames_without_keeping_intermediates(self):
235+
output_anc = os.path.join(self.tempdir.name, "test1")
236+
output_anc_ts = os.path.join(self.tempdir.name, "test2")
237+
with pytest.raises(ValueError, match="--keep-intermediates"):
238+
self.run_command(["infer", self.sample_file, "-a", output_anc])
239+
with pytest.raises(ValueError, match="--keep-intermediates"):
240+
self.run_command(["infer", self.sample_file, "-A", output_anc_ts])
241+
242+
def test_keep_intermediates(self):
243+
output_anc = os.path.join(self.tempdir.name, "test1")
244+
output_anc_ts = os.path.join(self.tempdir.name, "test2")
245+
self.run_command(
246+
[
247+
"infer",
248+
self.sample_file,
249+
"--keep-intermediates",
250+
"-a",
251+
output_anc,
252+
"-A",
253+
output_anc_ts,
254+
]
255+
)
256+
assert os.path.exists(output_anc)
257+
ancestors = tsinfer.load(output_anc)
258+
assert ancestors.num_ancestors > 0
259+
260+
assert os.path.exists(output_anc_ts)
261+
anc_ts = tskit.load(output_anc_ts)
262+
assert anc_ts.num_samples > 0
263+
264+
229265
class TestProgress(TestCli):
230266
"""
231267
Tests that we get some output when we use the progress bar.

tsinfer/cli.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -162,37 +162,53 @@ def run_infer(args):
162162
"via the Python function `tsinfer.SampleData.from_tree_sequence()`)."
163163
)
164164
sample_data = tsinfer.SampleData.load(args.samples)
165-
ts = tsinfer.infer(
166-
sample_data,
167-
progress_monitor=args.progress,
168-
num_threads=args.num_threads,
169-
recombination_rate=get_recombination_map(args),
170-
mismatch_ratio=args.mismatch_ratio,
171-
record_provenance=False,
172-
)
173-
output_trees = get_output_trees_path(args.output_trees, args.samples)
174-
write_ts(ts, output_trees)
165+
if args.keep_intermediates:
166+
run_generate_ancestors(args, usage_summary=False)
167+
run_match_ancestors(args, usage_summary=False)
168+
run_match_samples(args, usage_summary=False)
169+
else:
170+
if args.ancestors is not None:
171+
raise ValueError(
172+
"Must specify --keep-intermediates to save an ancestors file"
173+
)
174+
if args.ancestors_trees is not None:
175+
raise ValueError(
176+
"Must specify --keep-intermediates to save an ancestors tree sequence"
177+
)
178+
179+
ts = tsinfer.infer(
180+
sample_data,
181+
progress_monitor=args.progress,
182+
num_threads=args.num_threads,
183+
recombination_rate=get_recombination_map(args),
184+
mismatch_ratio=args.mismatch_ratio,
185+
path_compression=not args.no_path_compression,
186+
record_provenance=False,
187+
)
188+
output_trees = get_output_trees_path(args.output_trees, args.samples)
189+
write_ts(ts, output_trees)
175190
summarise_usage()
176191

177192

178-
def run_generate_ancestors(args):
193+
def run_generate_ancestors(args, usage_summary=True):
179194
setup_logging(args)
180195
ancestors_path = get_ancestors_path(args.ancestors, args.samples)
181196
sample_data = tsinfer.SampleData.load(args.samples)
182197
tsinfer.generate_ancestors(
183198
sample_data,
184199
progress_monitor=args.progress,
185-
num_flush_threads=args.num_flush_threads,
200+
num_flush_threads=getattr(args, "num_flush_threads", 0),
186201
num_threads=args.num_threads,
187202
path=ancestors_path,
188203
record_provenance=False,
189204
)
190205
# NB: ideally we should store the cli provenance in here, but this creates
191206
# perf issues - see https://github.com/tskit-dev/tsinfer/issues/743
192-
summarise_usage()
207+
if usage_summary:
208+
summarise_usage()
193209

194210

195-
def run_match_ancestors(args):
211+
def run_match_ancestors(args, usage_summary=True):
196212
setup_logging(args)
197213
ancestors_path = get_ancestors_path(args.ancestors, args.samples)
198214
logger.info(f"Loading ancestral haplotypes from {ancestors_path}")
@@ -210,10 +226,11 @@ def run_match_ancestors(args):
210226
record_provenance=False,
211227
)
212228
write_ts(ts, ancestors_trees)
213-
summarise_usage()
229+
if usage_summary:
230+
summarise_usage()
214231

215232

216-
def run_augment_ancestors(args):
233+
def run_augment_ancestors(args, usage_summary=True):
217234
setup_logging(args)
218235

219236
sample_data = tsinfer.SampleData.load(args.samples)
@@ -241,10 +258,11 @@ def run_augment_ancestors(args):
241258
)
242259
logger.info(f"Writing output tree sequence to {output_path}")
243260
ts.dump(output_path)
244-
summarise_usage()
261+
if usage_summary:
262+
summarise_usage()
245263

246264

247-
def run_match_samples(args):
265+
def run_match_samples(args, usage_summary=True):
248266
setup_logging(args)
249267

250268
sample_data = tsinfer.SampleData.load(args.samples)
@@ -264,7 +282,8 @@ def run_match_samples(args):
264282
record_provenance=False,
265283
)
266284
write_ts(ts, output_trees)
267-
summarise_usage()
285+
if usage_summary:
286+
summarise_usage()
268287

269288

270289
def run_verify(args):
@@ -425,6 +444,19 @@ def add_num_flush_threads_argument(parser):
425444
)
426445

427446

447+
def add_keep_intermediates_argument(parser):
448+
parser.add_argument(
449+
"--keep-intermediates",
450+
"-k",
451+
action="store_true",
452+
help=(
453+
"Keep the intermediate ancestors and ancestors-tree-sequence files. "
454+
"To override the default locations where these files are saved, use the "
455+
"--ancestors and --ancestors-trees options"
456+
),
457+
)
458+
459+
428460
def get_cli_parser():
429461
top_parser = argparse.ArgumentParser(
430462
description="Command line interface for tsinfer."
@@ -525,17 +557,21 @@ def get_cli_parser():
525557
"infer",
526558
help=(
527559
"Runs the generate-ancestors, match-ancestors and match-samples "
528-
"commands without writing the intermediate files to disk. Not "
529-
"recommended for large inferences."
560+
"steps in one go. Not recommended for large inferences."
530561
),
531562
)
532563
add_samples_file_argument(parser)
533564
add_logging_arguments(parser)
534565
add_output_trees_argument(parser)
566+
add_path_compression_argument(parser)
535567
add_num_threads_argument(parser)
536568
add_progress_argument(parser)
569+
add_postprocess_argument(parser)
537570
add_recombination_arguments(parser)
538571
add_mismatch_argument(parser)
572+
add_keep_intermediates_argument(parser)
573+
add_ancestors_file_argument(parser) # Only used if keep-intermediates
574+
add_ancestors_trees_argument(parser) # Only used if keep-intermediates
539575
parser.set_defaults(runner=run_infer)
540576

541577
parser = subparsers.add_parser(

0 commit comments

Comments
 (0)