Skip to content

Commit aeb3fcf

Browse files
Merge pull request #499 from jeromekelleher/deletion-sites
Switch map-deletions to require explicit list
2 parents 41f9a9f + e8389a0 commit aeb3fcf

File tree

4 files changed

+71
-101
lines changed

4 files changed

+71
-101
lines changed

sc2ts/cli.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -589,28 +589,16 @@ def minimise_metadata(
589589
@click.command()
590590
@click.argument("dataset", type=click.Path(exists=True, dir_okay=False))
591591
@click.argument("ts_in", type=click.Path(exists=True, dir_okay=False))
592+
@click.argument("sites", type=click.Path(exists=True, dir_okay=False))
592593
@click.argument("ts_out", type=click.Path(exists=False, dir_okay=False))
593-
@click.option(
594-
"--frequency-threshold",
595-
type=float,
596-
default=0.01,
597-
help="Frequency threshold for deletions to get mapped back",
598-
)
599-
@click.option(
600-
"--mutations-threshold",
601-
type=int,
602-
default=None,
603-
help="Maximum number of mutations at a site after parsimony",
604-
)
605594
@click.option("--progress/--no-progress", default=True)
606595
@click.option("-v", "--verbose", count=True)
607596
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
608597
def map_deletions(
609-
ts_in,
610598
dataset,
599+
ts_in,
600+
sites,
611601
ts_out,
612-
frequency_threshold,
613-
mutations_threshold,
614602
progress,
615603
verbose,
616604
log_file,
@@ -621,13 +609,8 @@ def map_deletions(
621609
setup_logging(verbose, log_file)
622610
ds = sc2ts.Dataset(dataset)
623611
ts = tszip.load(ts_in)
624-
ts = sc2ts.map_deletions(
625-
ts,
626-
ds,
627-
frequency_threshold=frequency_threshold,
628-
mutations_threshold=mutations_threshold,
629-
show_progress=progress,
630-
)
612+
sites = np.loadtxt(sites, dtype=int)
613+
ts = sc2ts.map_deletions(ts, ds, sites, show_progress=progress)
631614
ts.dump(ts_out)
632615

633616

sc2ts/inference.py

Lines changed: 34 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,38 +1954,13 @@ def get_recombinant_strains(ts):
19541954
return ret
19551955

19561956

1957-
def map_deletions(
1958-
ts, ds, *, frequency_threshold, mutations_threshold=None, show_progress=False
1959-
):
1957+
def map_deletions(ts, ds, sites, *, show_progress=False):
19601958
"""
1961-
Map deletions at sites that exceed the specified frequency threshold onto the
1962-
ARG using parsimony (excluding flanks), and insert the mutations for any sites
1963-
that have less than the specified threshold number of mutations.
1959+
Map deletions at the specified set of site positions onto the ARG using parsimony.
19641960
"""
1965-
mutations_threshold = 2**64 if mutations_threshold is None else mutations_threshold
19661961
start_time = time.time() # wall time
1967-
genes = core.get_gene_coordinates()
1968-
start = genes["ORF1ab"][0]
1969-
end = genes["ORF10"][1]
1970-
1962+
sites = np.array(sites, dtype=int)
19711963
md = ts.metadata
1972-
num_samples = ts.num_samples
1973-
ts_contains_exact_matches = md["sc2ts"].get("includes_exact_matches")
1974-
if ts_contains_exact_matches:
1975-
total_exact_matches = sum(
1976-
md["sc2ts"]["cumulative_stats"]["exact_matches"]["pango"].values()
1977-
)
1978-
logger.info(
1979-
f"Exact matches included; adjusting num_samples to remove {total_exact_matches}"
1980-
)
1981-
num_samples -= total_exact_matches
1982-
1983-
del_sites = []
1984-
for site in ts.sites():
1985-
if start <= site.position < end:
1986-
deletion_samples = site.metadata["sc2ts"]["deletion_samples"]
1987-
if deletion_samples / ts.num_samples >= frequency_threshold:
1988-
del_sites.append(site.id)
19891964

19901965
sample_id = md["sc2ts"]["samples_strain"]
19911966
assert not sample_id[0].startswith("Wuhan")
@@ -1994,26 +1969,27 @@ def map_deletions(
19941969
tree = ts.first()
19951970

19961971
variants = get_progress(
1997-
ds.variants(sample_id, ts.sites_position[del_sites]),
1972+
ds.variants(sample_id, sites),
19981973
title="Map deletions",
19991974
phase="",
20001975
show_progress=show_progress,
2001-
total=len(del_sites),
1976+
total=len(sites),
20021977
)
20031978

2004-
logger.info(f"Remapping {len(del_sites)} sites")
1979+
logger.info(f"Remapping {len(sites)} sites")
20051980

20061981
mut_metadata = {"sc2ts": {"type": "post_parsimony"}}
20071982
site_metadata = {}
20081983
keep_mutations = np.ones(ts.num_mutations, dtype=bool)
20091984
for var in variants:
20101985
tree.seek(var.position)
2011-
site = ts.site(position=var.position)
1986+
try:
1987+
site = ts.site(position=var.position)
1988+
except ValueError:
1989+
logger.warning(f"No site at position {var.position}; skipping")
1990+
continue
20121991

20131992
g = mask_ambiguous(var.genotypes)
2014-
deletion_samples = site.metadata["sc2ts"]["deletion_samples"]
2015-
if not ts_contains_exact_matches:
2016-
assert deletion_samples == np.sum(g == DELETION)
20171993
_, mutations = tree.map_mutations(
20181994
g, list(var.alleles), ancestral_state=site.ancestral_state
20191995
)
@@ -2022,31 +1998,28 @@ def map_deletions(
20221998
f"Site {int(site.position)} "
20231999
f"mapped mutations = {len(mutations)}; current = {len(site.mutations)}"
20242000
)
2025-
if len(mutations) < mutations_threshold:
2026-
2027-
old_mutations = []
2028-
for mut in site.mutations:
2029-
keep_mutations[mut.id] = False
2030-
old_mutations.append(
2031-
{
2032-
"node": mut.node,
2033-
"derived_state": mut.derived_state,
2034-
"metadata": mut.metadata,
2035-
}
2036-
)
2037-
md = dict(site.metadata)
2038-
md["sc2ts"]["original_mutations"] = old_mutations
2039-
site_metadata[site.id] = md
2040-
for m in mutations:
2041-
tables.mutations.add_row(
2042-
site=site.id,
2043-
node=m.node,
2044-
derived_state=m.derived_state,
2045-
time=ts.nodes_time[m.node],
2046-
metadata=mut_metadata,
2047-
)
2048-
else:
2049-
logger.debug(f"Skipping site {int(site.position)} ")
2001+
2002+
old_mutations = []
2003+
for mut in site.mutations:
2004+
keep_mutations[mut.id] = False
2005+
old_mutations.append(
2006+
{
2007+
"node": mut.node,
2008+
"derived_state": mut.derived_state,
2009+
"metadata": mut.metadata,
2010+
}
2011+
)
2012+
md = dict(site.metadata)
2013+
md["sc2ts"]["original_mutations"] = old_mutations
2014+
site_metadata[site.id] = md
2015+
for m in mutations:
2016+
tables.mutations.add_row(
2017+
site=site.id,
2018+
node=m.node,
2019+
derived_state=m.derived_state,
2020+
time=ts.nodes_time[m.node],
2021+
metadata=mut_metadata,
2022+
)
20502023

20512024
added_mutations = len(tables.mutations) - ts.num_mutations
20522025
logger.info(
@@ -2064,11 +2037,7 @@ def map_deletions(
20642037
tables.sort()
20652038
tables.build_index()
20662039
tables.compute_mutation_parents()
2067-
params = {
2068-
"dataset": str(ds.path),
2069-
"frequency_threshold": float(frequency_threshold),
2070-
"mutations_threshold": int(mutations_threshold),
2071-
}
2040+
params = {"dataset": str(ds.path), "sites": sites.tolist()}
20722041
prov = get_provenance_dict("map_deletions", params, start_time)
20732042
tables.provenances.add_row(json.dumps(prov))
20742043
return tables.tree_sequence()

tests/test_cli.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,11 +417,15 @@ class TestMapDeletions:
417417
def test_example(self, tmp_path, fx_ts_map, fx_dataset):
418418
ts = fx_ts_map["2020-02-13"]
419419
out_ts_path = tmp_path / "ts.ts"
420+
del_sites_path = tmp_path / "deletion_sites.txt"
421+
with open(del_sites_path, "w") as f:
422+
print("1547 3951 3952 3953", file=f)
423+
420424
runner = ct.CliRunner()
421425
result = runner.invoke(
422426
cli.cli,
423-
f"map-deletions {fx_dataset.path} {ts.path} {out_ts_path} "
424-
"--frequency-threshold=0.0001",
427+
f"map-deletions {fx_dataset.path} {ts.path} {del_sites_path} "
428+
f"{out_ts_path} ",
425429
catch_exceptions=False,
426430
)
427431
assert result.exit_code == 0

tests/test_inference.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,13 +1398,14 @@ def fx_ts_exact_matches(fx_ts_map, fx_match_db):
13981398
class TestMapDeletions:
13991399
def test_example(self, fx_ts_map, fx_dataset):
14001400
ts = fx_ts_map["2020-02-13"]
1401-
new_ts = sc2ts.map_deletions(ts, fx_dataset, frequency_threshold=0.001)
1401+
sites = [1547, 3951, 3952, 3953]
1402+
new_ts = sc2ts.map_deletions(ts, fx_dataset, sites)
14021403
remapped_sites = [
14031404
j
14041405
for j in range(ts.num_sites)
14051406
if "original_mutations" in new_ts.site(j).metadata["sc2ts"]
14061407
]
1407-
assert remapped_sites == [1541, 3945, 3946, 3947]
1408+
assert remapped_sites == list(np.searchsorted(ts.sites_position, sites))
14081409

14091410
for site_id in remapped_sites:
14101411
site = new_ts.site(site_id)
@@ -1425,11 +1426,22 @@ def test_example(self, fx_ts_map, fx_dataset):
14251426
for mut in site.mutations:
14261427
assert mut.metadata["sc2ts"]["type"] == "post_parsimony"
14271428

1428-
def test_filter_all(self, fx_ts_map, fx_dataset):
1429+
def test_empty(self, fx_ts_map, fx_dataset):
14291430
ts = fx_ts_map["2020-02-13"]
1430-
new_ts = sc2ts.map_deletions(
1431-
ts, fx_dataset, frequency_threshold=0.001, mutations_threshold=0
1432-
)
1431+
new_ts = sc2ts.map_deletions(ts, fx_dataset, [])
1432+
remapped_sites = [
1433+
j
1434+
for j in range(ts.num_sites)
1435+
if "original_mutations" in new_ts.site(j).metadata["sc2ts"]
1436+
]
1437+
assert remapped_sites == []
1438+
1439+
def test_missing_site(self, fx_ts_map, fx_dataset):
1440+
ts = fx_ts_map["2020-02-13"]
1441+
missing_positions = [56, 57, 58, 59, 60]
1442+
assert len(set(missing_positions) & set(ts.sites_position.astype(int))) == 0
1443+
1444+
new_ts = sc2ts.map_deletions(ts, fx_dataset, missing_positions)
14331445
remapped_sites = [
14341446
j
14351447
for j in range(ts.num_sites)
@@ -1439,29 +1451,31 @@ def test_filter_all(self, fx_ts_map, fx_dataset):
14391451

14401452
def test_example_exact_matches(self, fx_ts_exact_matches, fx_dataset):
14411453
ts = fx_ts_exact_matches
1442-
new_ts = sc2ts.map_deletions(ts, fx_dataset, frequency_threshold=0.001)
1454+
sites = [1547, 3951, 3952, 3953]
1455+
new_ts = sc2ts.map_deletions(ts, fx_dataset, sites)
14431456
remapped_sites = [
14441457
j
14451458
for j in range(ts.num_sites)
14461459
if "original_mutations" in new_ts.site(j).metadata["sc2ts"]
14471460
]
1448-
assert remapped_sites == [1541, 3945, 3946, 3947]
1461+
assert remapped_sites == list(np.searchsorted(ts.sites_position, sites))
14491462

14501463
def test_validate(self, fx_ts_map, fx_dataset):
14511464
ts = fx_ts_map["2020-02-13"]
1452-
new_ts = sc2ts.map_deletions(ts, fx_dataset, frequency_threshold=0.001)
1465+
sites = [1547, 3951, 3952, 3953]
1466+
new_ts = sc2ts.map_deletions(ts, fx_dataset, sites)
14531467
sc2ts.validate(new_ts, fx_dataset, deletions_as_missing=False)
14541468

14551469
def test_provenance(self, fx_ts_map, fx_dataset):
14561470
ts = fx_ts_map["2020-02-13"]
1457-
tsp = sc2ts.map_deletions(ts, fx_dataset, frequency_threshold=0.125)
1471+
sites = [1547, 3951, 3952, 3953]
1472+
tsp = sc2ts.map_deletions(ts, fx_dataset, sites)
14581473
assert tsp.num_provenances == ts.num_provenances + 1
14591474
prov = tsp.provenance(-1)
14601475
assert json.loads(prov.record)["parameters"] == {
14611476
"command": "map_deletions",
14621477
"dataset": str(fx_dataset.path),
1463-
"frequency_threshold": 0.125,
1464-
"mutations_threshold": 2**64,
1478+
"sites": sites,
14651479
}
14661480

14671481

0 commit comments

Comments
 (0)