Skip to content

Commit 0705c54

Browse files
Merge pull request #476 from jeromekelleher/postprocess
Postprocess
2 parents fc354a1 + 6f2af11 commit 0705c54

File tree

5 files changed

+309
-2
lines changed

5 files changed

+309
-2
lines changed

sc2ts/cli.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,41 @@ def run_hmm(
543543
print(run.asjson())
544544

545545

546+
@click.command()
547+
@click.argument("ts_in", type=click.Path(exists=True, dir_okay=False))
548+
@click.argument("ts_out", type=click.Path(exists=False, dir_okay=False))
549+
@click.option("--match-db", type=click.Path(exists=True, dir_okay=False))
550+
@click.option("--progress/--no-progress", default=True)
551+
@click.option("-v", "--verbose", count=True)
552+
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
553+
def postprocess(
554+
ts_in,
555+
ts_out,
556+
match_db,
557+
progress,
558+
verbose,
559+
log_file,
560+
):
561+
"""
562+
Perform final postprocessing steps to the specified ARG.
563+
"""
564+
setup_logging(verbose, log_file)
565+
ts = tszip.load(ts_in)
566+
if match_db is not None:
567+
with sc2ts.MatchDb(match_db) as db:
568+
ts = sc2ts.append_exact_matches(ts, db, show_progress=progress)
569+
570+
ts = sc2ts.trim_metadata(ts, show_progress=progress)
571+
572+
ts = sc2ts.push_up_unary_recombinant_mutations(ts)
573+
574+
# See if we can remove some of the reversions in a straightforward way.
575+
mutations_is_reversion = sc2ts.find_reversions(ts)
576+
mutations_before = ts.num_mutations
577+
ts = sc2ts.push_up_reversions(ts, ts.mutations_node[mutations_is_reversion])
578+
ts.dump(ts_out)
579+
580+
546581
def find_previous_date_path(date, path_pattern):
547582
"""
548583
Find the path with the most-recent date to the specified one
@@ -577,6 +612,7 @@ def cli():
577612

578613
cli.add_command(infer)
579614
cli.add_command(validate)
615+
cli.add_command(postprocess)
580616
cli.add_command(run_hmm)
581617

582618
cli.add_command(tally_lineages)

sc2ts/inference.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ def update_top_level_metadata(ts, date, retro_groups, samples):
814814
return tables
815815

816816

817-
def add_sample_to_tables(sample, tables, group_id=None):
817+
def add_sample_to_tables(sample, tables, group_id=None, time=0):
818818
sc2ts_md = {
819819
"hmm_match": sample.hmm_match.asdict(),
820820
"alignment_composition": dict(sample.alignment_composition),
@@ -825,7 +825,7 @@ def add_sample_to_tables(sample, tables, group_id=None):
825825
if group_id is not None:
826826
sc2ts_md["group_id"] = group_id
827827
metadata = {**sample.metadata, "sc2ts": sc2ts_md}
828-
return tables.nodes.add_row(flags=sample.flags, metadata=metadata)
828+
return tables.nodes.add_row(flags=sample.flags, metadata=metadata, time=time)
829829

830830

831831
def match_path_ts(group):
@@ -2031,3 +2031,133 @@ def map_deletions(ts, ds, *, frequency_threshold, show_progress=False):
20312031
tables.build_index()
20322032
tables.compute_mutation_parents()
20332033
return tables.tree_sequence()
2034+
2035+
2036+
def append_exact_matches(ts, match_db, show_progress=False):
2037+
"""
2038+
Update the specified tree sequence to include all exact matches
2039+
from the specified match DB.
2040+
"""
2041+
md = ts.metadata
2042+
date = md["sc2ts"]["date"]
2043+
total_exact_matches = sum(
2044+
md["sc2ts"]["cumulative_stats"]["exact_matches"]["pango"].values()
2045+
)
2046+
samples_strain = md["sc2ts"]["samples_strain"]
2047+
tables = ts.dump_tables()
2048+
L = tables.sequence_length
2049+
time_zero = parse_date(date)
2050+
with match_db.conn:
2051+
sql = f"SELECT * FROM samples WHERE hmm_cost == 0 AND match_date <= '{date}'"
2052+
rows = tqdm.tqdm(
2053+
match_db.conn.execute(sql),
2054+
total=total_exact_matches,
2055+
desc="Exact matches",
2056+
disable=not show_progress,
2057+
)
2058+
for row in rows:
2059+
pkl = row.pop("pickle")
2060+
sample = pickle.loads(bz2.decompress(pkl))
2061+
sample.flags |= core.NODE_IS_EXACT_MATCH
2062+
delta = time_zero - parse_date(sample.date)
2063+
assert delta.days >= 0
2064+
u = add_sample_to_tables(sample, tables, time=delta.days)
2065+
parent = sample.hmm_match.path[0].parent
2066+
tables.edges.add_row(0, L, parent=parent, child=u)
2067+
samples_strain.append(sample.strain)
2068+
2069+
assert total_exact_matches == len(tables.nodes) - ts.num_nodes
2070+
md["sc2ts"]["samples_strain"] = samples_strain
2071+
tables.metadata = md
2072+
tables.sort()
2073+
return tables.tree_sequence()
2074+
2075+
2076+
def trim_metadata(ts, show_progress=False):
2077+
tables = ts.dump_tables()
2078+
2079+
tables.nodes.clear()
2080+
2081+
nodes = tqdm.tqdm(
2082+
ts.nodes(),
2083+
total=ts.num_nodes,
2084+
desc="Trim node metadata",
2085+
disable=not show_progress,
2086+
)
2087+
for node in nodes:
2088+
md = node.metadata
2089+
if node.is_sample():
2090+
# Note it would be nice to trim down the name of the pango field here
2091+
# but it's too tedious to test.
2092+
md = {k: md[k] for k in ["strain", "date", "Viridian_pangolin"]}
2093+
tables.nodes.append(node.replace(metadata=md))
2094+
return tables.tree_sequence()
2095+
2096+
2097+
def find_reversions(ts):
2098+
"""
2099+
Return a boolean array with True for all mutations in which the
2100+
inherited_state of the parent is equal to the derived_state of the
2101+
child.
2102+
"""
2103+
tables = ts.tables
2104+
assert np.all(
2105+
tables.mutations.derived_state_offset == np.arange(ts.num_mutations + 1)
2106+
)
2107+
derived_state = tables.mutations.derived_state.view("S1").astype(str)
2108+
assert np.all(tables.sites.ancestral_state_offset == np.arange(ts.num_sites + 1))
2109+
ancestral_state = tables.sites.ancestral_state.view("S1").astype(str)
2110+
del tables
2111+
inherited_state = ancestral_state[ts.mutations_site]
2112+
mutations_with_parent = ts.mutations_parent != -1
2113+
parent = ts.mutations_parent[mutations_with_parent]
2114+
assert np.all(parent >= 0)
2115+
inherited_state[mutations_with_parent] = derived_state[parent]
2116+
2117+
assert np.all(inherited_state != derived_state)
2118+
2119+
is_reversion = np.zeros(ts.num_mutations, dtype=bool)
2120+
is_reversion[mutations_with_parent] = (
2121+
derived_state[mutations_with_parent] == inherited_state[parent]
2122+
)
2123+
return is_reversion
2124+
2125+
2126+
def push_up_unary_recombinant_mutations(ts):
2127+
"""
2128+
Find any mutations that occur on unary children of a recombinant node,
2129+
and push those mutations onto the recombinant node itself. The
2130+
rationale for this is that, due to technical details of tree building,
2131+
we sometimes get a single child of a recombinant node, which can have
2132+
a large number of mutations. It is more parsimonious to assume that the
2133+
mutations occured on the branch(es) *leading to* the recombinant than
2134+
to have succeeded it.
2135+
"""
2136+
recomb_parent_edges = np.where(
2137+
ts.nodes_flags[ts.edges_parent] & core.NODE_IS_RECOMBINANT > 0
2138+
)[0]
2139+
by_parent = collections.defaultdict(list)
2140+
logger.info(f"Found {len(recomb_parent_edges)} edges with recombinant parent")
2141+
for e in recomb_parent_edges:
2142+
edge = ts.edge(e)
2143+
if edge.left == 0 and edge.right == ts.sequence_length:
2144+
by_parent[edge.parent].append(edge)
2145+
2146+
# We're only interested in full-span edges with a single child.
2147+
child_to_parent = {
2148+
e[0].child: e[0].parent for e in by_parent.values() if len(e) == 1
2149+
}
2150+
logger.info(f"Of which {len(child_to_parent)} are unary")
2151+
mutations_to_move = np.isin(
2152+
ts.mutations_node, np.array(list(child_to_parent.keys()), dtype=np.int32)
2153+
)
2154+
tables = ts.dump_tables()
2155+
for m in np.where(mutations_to_move)[0]:
2156+
row = tables.mutations[m]
2157+
node = child_to_parent[row.node]
2158+
# We're only changing the node and time, which are fixed size so we
2159+
# don't rewrite the table for each of these.
2160+
tables.mutations[m] = row.replace(node=node, time=ts.nodes_time[node])
2161+
logger.info(f"Moved up {np.sum(mutations_to_move)} mutations")
2162+
tables.sort()
2163+
return tables.tree_sequence()

tests/conftest.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,19 @@ def recombinant_example_3(tmp_path, fx_ts_map, fx_dataset, ds_path):
364364
return rts
365365

366366

367+
def recombinant_example_4(tmp_path, fx_recombinant_example_2):
368+
"""
369+
Same as recombinant_ex2 but with two mutations below the recombinant.
370+
"""
371+
tables = fx_recombinant_example_2.dump_tables()
372+
u = 55
373+
tables.mutations.add_row(
374+
site=2500, derived_state="G", node=u, time=tables.nodes.time[u], metadata={}
375+
)
376+
tables.sort()
377+
return tables.tree_sequence()
378+
379+
367380
@pytest.fixture
368381
def fx_recombinant_example_1(tmp_path, fx_data_cache, fx_ts_map, fx_dataset):
369382
cache_path = fx_data_cache / "recombinant_ex1.ts"
@@ -395,3 +408,13 @@ def fx_recombinant_example_3(tmp_path, fx_data_cache, fx_ts_map, fx_dataset):
395408
ts = recombinant_example_3(tmp_path, fx_ts_map, fx_dataset, ds_cache_path)
396409
ts.dump(cache_path)
397410
return tskit.load(cache_path)
411+
412+
413+
@pytest.fixture
414+
def fx_recombinant_example_4(tmp_path, fx_data_cache, fx_recombinant_example_2):
415+
cache_path = fx_data_cache / "recombinant_ex4.ts"
416+
if not cache_path.exists():
417+
print(f"Generating {cache_path}")
418+
ts = recombinant_example_4(tmp_path, fx_recombinant_example_2)
419+
ts.dump(cache_path)
420+
return tskit.load(cache_path)

tests/test_cli.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,22 @@ def test_multiple_override(self, tmp_path, fx_ts_map, fx_dataset):
376376
assert ts.num_samples == 0
377377

378378

379+
class TestPostprocess:
380+
381+
def test_example(self, tmp_path, fx_ts_map, fx_match_db):
382+
ts = fx_ts_map["2020-02-13"]
383+
out_ts_path = tmp_path / "ts.ts"
384+
runner = ct.CliRunner(mix_stderr=False)
385+
result = runner.invoke(
386+
cli.cli,
387+
f"postprocess {ts.path} {out_ts_path} --match-db={fx_match_db.path}",
388+
catch_exceptions=False,
389+
)
390+
assert result.exit_code == 0
391+
out = tskit.load(out_ts_path)
392+
assert out.num_samples == ts.num_samples + 8
393+
394+
379395
class TestValidate:
380396

381397
@pytest.mark.parametrize("date", ["2020-01-01", "2020-02-11"])

tests/test_inference.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def test_get_recombinant_strains_ex2(self, fx_recombinant_example_2):
7171
d = sc2ts.get_recombinant_strains(fx_recombinant_example_2)
7272
assert d == {56: ["recombinant_114:29825"]}
7373

74+
def test_get_recombinant_strains_ex4(self, fx_recombinant_example_4):
75+
d = sc2ts.get_recombinant_strains(fx_recombinant_example_4)
76+
assert d == {56: ["recombinant_114:29825"]}
77+
7478
def test_recombinant_example_1(self, fx_recombinant_example_1):
7579
ts = fx_recombinant_example_1
7680
samples_strain = ts.metadata["sc2ts"]["samples_strain"]
@@ -1412,3 +1416,101 @@ def test_validate(self, fx_ts_map, fx_dataset):
14121416
ts = fx_ts_map["2020-02-13"]
14131417
new_ts = sc2ts.map_deletions(ts, fx_dataset, frequency_threshold=0.001)
14141418
sc2ts.validate(new_ts, fx_dataset, deletions_as_missing=False)
1419+
1420+
1421+
@pytest.fixture
1422+
def fx_ts_exact_matches(fx_ts_map, fx_match_db):
1423+
ts = fx_ts_map["2020-02-13"]
1424+
tsp = sc2ts.append_exact_matches(ts, fx_match_db)
1425+
return tsp
1426+
1427+
1428+
class TestAppendExactMatches:
1429+
def test_validate(self, fx_ts_exact_matches, fx_dataset):
1430+
sc2ts.validate(fx_ts_exact_matches, fx_dataset)
1431+
1432+
def test_example_properties(self, fx_ts_exact_matches):
1433+
ts = fx_ts_exact_matches
1434+
samples_strain = ts.metadata["sc2ts"]["samples_strain"]
1435+
assert [ts.node(u).metadata["strain"] for u in ts.samples()] == samples_strain
1436+
assert ts.num_nodes == 61
1437+
tree = ts.first()
1438+
assert tree.num_roots == 1
1439+
1440+
def test_times_agree(self, fx_ts_exact_matches):
1441+
ts = fx_ts_exact_matches
1442+
date_to_time = {}
1443+
time_to_date = {}
1444+
for u in ts.samples():
1445+
node = ts.node(u)
1446+
time = node.time
1447+
date = node.metadata["date"]
1448+
if date not in date_to_time:
1449+
date_to_time[date] = time
1450+
assert date_to_time[date] == time
1451+
if time not in time_to_date:
1452+
time_to_date[time] = date
1453+
assert time_to_date[time] == date
1454+
1455+
def test_flags(self, fx_ts_exact_matches):
1456+
ts = fx_ts_exact_matches
1457+
assert np.all(
1458+
ts.nodes_flags[-8:] == sc2ts.NODE_IS_EXACT_MATCH | tskit.NODE_IS_SAMPLE
1459+
)
1460+
1461+
def test_exact_match_counts(self, fx_ts_exact_matches):
1462+
ts = fx_ts_exact_matches
1463+
tree = ts.first()
1464+
node_count = ts.metadata["sc2ts"]["cumulative_stats"]["exact_matches"]["node"]
1465+
for u in tree.nodes():
1466+
num_exact_matches = 0
1467+
for v in tree.children(u):
1468+
if (ts.nodes_flags[v] & sc2ts.NODE_IS_EXACT_MATCH) > 0:
1469+
num_exact_matches += 1
1470+
assert node_count.get(str(u), 0) == num_exact_matches
1471+
1472+
1473+
class TestTrimMetadata:
1474+
def test_validate(self, fx_ts_map, fx_dataset):
1475+
ts = fx_ts_map["2020-02-13"]
1476+
tsp = sc2ts.trim_metadata(ts)
1477+
sc2ts.validate(tsp, fx_dataset)
1478+
1479+
def test_fields(self, fx_ts_map):
1480+
ts = fx_ts_map["2020-02-13"]
1481+
tsp = sc2ts.trim_metadata(ts)
1482+
for u in tsp.samples():
1483+
node = tsp.node(u)
1484+
assert set(node.metadata.keys()) == {"strain", "date", "Viridian_pangolin"}
1485+
1486+
1487+
class TestPushUpRecombinantMutations:
1488+
1489+
def test_no_recombinants(self, fx_ts_map):
1490+
ts = fx_ts_map["2020-02-13"]
1491+
tsp = sc2ts.push_up_unary_recombinant_mutations(ts)
1492+
ts.tables.assert_equals(tsp.tables)
1493+
1494+
def test_recombinant_example_1(self, fx_recombinant_example_1):
1495+
ts = fx_recombinant_example_1
1496+
tsp = sc2ts.push_up_unary_recombinant_mutations(ts)
1497+
ts.tables.assert_equals(tsp.tables)
1498+
1499+
def test_recombinant_example_2(self, fx_recombinant_example_2):
1500+
ts = fx_recombinant_example_2
1501+
tsp = sc2ts.push_up_unary_recombinant_mutations(ts)
1502+
ts.tables.assert_equals(tsp.tables)
1503+
1504+
def test_recombinant_example_3(self, fx_recombinant_example_3):
1505+
ts = fx_recombinant_example_3
1506+
tsp = sc2ts.push_up_unary_recombinant_mutations(ts)
1507+
ts.tables.assert_equals(tsp.tables)
1508+
1509+
def test_recombinant_example_4(self, fx_recombinant_example_4):
1510+
ts = fx_recombinant_example_4
1511+
site = 2500
1512+
mut = ts.site(site).mutations[0]
1513+
assert mut.node == 55
1514+
tsp = sc2ts.push_up_unary_recombinant_mutations(ts)
1515+
mut = tsp.site(site).mutations[0]
1516+
assert mut.node == 56

0 commit comments

Comments
 (0)