Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 65 additions & 39 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,6 @@ def node_mutations(self):
muts[site.position] = f"{state0}>{state1}"
return muts


def __init__(
self,
ts,
Expand All @@ -258,10 +257,13 @@ def __init__(
self.node = node
if edges is None: # the required edge table wasn't given, so recalculate
edges = tskit.EdgeTable()
for e in sorted([ts.edge(i) for i in np.where(ts.edges_child==node)[0]], key=lambda e: e.left):
for e in sorted(
[ts.edge(i) for i in np.where(ts.edges_child == node)[0]],
key=lambda e: e.left,
):
edges.append(e)
self.edges = edges

def html(
self,
show_bases=True,
Expand All @@ -278,7 +280,7 @@ def html(
using the ``IPython.display.HTML`` function.

:param ts TreeSequence:
The tree sequence to which the nodes refer
The tree sequence to which the nodes refer
:param node int:
The node ID of the child node, usually a recombination node.
This will be placed on the second row of the copying pattern, so that
Expand Down Expand Up @@ -311,6 +313,7 @@ def html(
document (e.g. a Jupyter notebook) that already has one copying table shown with
the standard stylesheet. If False or None (default), include the default stylesheet.
"""

def row_lab(txt):
return "" if hide_labels else f"<th>{txt}</th>"

Expand Down Expand Up @@ -448,11 +451,12 @@ def __init__(
quick=False,
show_progress=True,
pango_source="Viridian_pangolin",
scorpio_source="Viridian_scorpio",
sample_group_id_prefix_len=10,
):
self.ts = ts
self.pango_source = pango_source
self.scorpio_source = "Viridian_scorpio"
self.scorpio_source = scorpio_source
self.strain_map = {}
self.recombinants = np.where(ts.nodes_flags == core.NODE_IS_RECOMBINANT)[0]

Expand Down Expand Up @@ -967,11 +971,25 @@ def recombinants_summary(
):
if parent_pango_source is None:
parent_pango_source = self.pango_source

def node_info(node, label):
datum = {label: node}
datum[f"{label}_pango"] = self.nodes_metadata[node].get(
self.pango_source, "Unknown"
)
datum[f"{label}_scorpio"] = self.nodes_metadata[node].get(
self.scorpio_source, "Unknown"
)
datum[f"{label}_time"] = self.ts.nodes_time[node]
datum[f"{label}_date"] = self.nodes_date[node]
return datum

data = []
for u in self.recombinants:
md = dict(self.nodes_metadata[u]["sc2ts"])
group_id = md["group_id"][: self.sample_group_id_prefix_len]
md["group_id"] = group_id

group_nodes = self.sample_group_nodes[group_id]
md["group_size"] = len(group_nodes)

Expand All @@ -983,13 +1001,17 @@ def recombinants_summary(
causal_lineages = {}
hmm_matches = []
breakpoint_intervals = []
copying_path_mutations = collections.defaultdict(list)
for v in samples:
causal_lineages[v] = self.nodes_metadata[v].get(
self.pango_source, "Unknown"
)

# Arbitrarily pick the first sample node as the representative
v = samples[0]
sample_md = self.nodes_metadata[v]
causal_lineages[v] = sample_md.get(self.pango_source, "Unknown")
hmm_mutations = len(sample_md["sc2ts"]["hmm_match"]["mutations"])
copying_path_mutations[hmm_mutations].append(v)

min_mutations = min(copying_path_mutations.keys())
# Choose our representative sample as one of the ones that have the
# fewest mutations in it's copying path.
v = copying_path_mutations[min_mutations][0]
node_md = self.nodes_metadata[v]["sc2ts"]
hmm_matches.append(node_md["hmm_match"])
breakpoint_intervals.append(node_md["breakpoint_intervals"])
Expand All @@ -1003,30 +1025,33 @@ def recombinants_summary(
interval = breakpoint_intervals[0]
parent_left = hmm_match["path"][0]["parent"]
parent_right = hmm_match["path"][1]["parent"]
data.append(
{
"recombinant": u,
"descendants": self.nodes_max_descendant_samples[u],
"sample": v,
"sample_pango": causal_lineages[v],
"num_samples": len(samples),
"distinct_sample_pango": len(set(causal_lineages.values())),
"interval_left": interval[0][0],
"interval_right": interval[0][1],
"parent_left": parent_left,
"parent_right": parent_right,
"parent_left_pango": self.nodes_metadata[parent_left].get(
parent_pango_source,
"Unknown",
),
"parent_right_pango": self.nodes_metadata[parent_right].get(
parent_pango_source,
"Unknown",
),
"num_mutations": len(hmm_match["mutations"]),
**md,
}
)

datum = {
"num_descendant_samples": self.nodes_max_descendant_samples[u],
"num_samples": len(samples),
"distinct_sample_pango": len(set(causal_lineages.values())),
"interval_left": interval[0][0],
"interval_right": interval[0][1],
"num_mutations": len(hmm_match["mutations"]),
"Viridian_amplicon_scheme": self.nodes_metadata[v].get(
"Viridian_amplicon_scheme", "Unknown"
),
"Artic_primer_version": self.nodes_metadata[v].get(
"Artic_primer_version", "Unknown"
),
**md,
}

for node, label in [
(u, "recombinant"),
(v, "sample"),
(parent_left, "parent_left"),
(parent_right, "parent_right"),
]:
datum = {**datum, **node_info(node, label)}

data.append(datum)

# Compute the MRCAs by iterating along trees in order of
# breakpoint. We use the right interval
df = pd.DataFrame(data).sort_values("interval_right")
Expand All @@ -1043,10 +1068,11 @@ def recombinants_summary(
left_path = jit.get_root_path(tree, row.parent_left)
assert tree.parent(row.recombinant) == row.parent_left
mrca = jit.get_path_mrca(left_path, right_path, self.ts.nodes_time)
mrca_data.append(mrca)
mrca_data = np.array(mrca_data)
df["mrca"] = mrca_data
df["t_mrca"] = self.ts.nodes_time[mrca_data]
mrca_data.append(node_info(mrca, "parent_mrca"))

mrca_df = pd.DataFrame(mrca_data)
for col in mrca_df:
df[col] = mrca_df[col]

if characterise_copying:
# Slow - don't do this unless we really want to.
Expand Down
14 changes: 7 additions & 7 deletions tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def test_recombinants_summary_example_1(self, fx_ti_recombinant_example_1):
df = fx_ti_recombinant_example_1.recombinants_summary()
assert df.shape[0] == 1
row = df.iloc[0]
assert row.descendants == 2
assert row.num_descendant_samples == 2
assert row["sample"] == 53
assert row.num_samples == 2
assert row.group_size == 3
Expand All @@ -189,8 +189,8 @@ def test_recombinants_summary_example_1(self, fx_ti_recombinant_example_1):
assert row.parent_right == 46
assert row.parent_right_pango == "Unknown"
assert row.num_mutations == 0
assert row.mrca == 1
assert row.t_mrca == 51
assert row.parent_mrca == 1
assert row.parent_mrca_time == 51
assert "diffs" not in df

df2 = fx_ti_recombinant_example_1.recombinants_summary(
Expand All @@ -206,17 +206,19 @@ def test_recombinants_summary_example_2(self, fx_recombinant_example_2):
df = ti.recombinants_summary(characterise_copying=True, show_progress=False)
assert df.shape[0] == 1
row = df.iloc[0]
assert row.descendants == 1
assert row.num_descendant_samples == 1
assert row["sample"] == 55
assert row["distinct_sample_pango"] == 1
assert row["recombinant"] == 56
assert row["recombinant_pango"] == "Unknown"
assert row["recombinant_time"] == 0.000001
assert row["sample_pango"] == "Unknown"
assert row["num_mutations"] == 0
assert row["parent_left"] == 53
assert row["parent_left_pango"] == "Unknown"
assert row["parent_right"] == 54
assert row["parent_right_pango"] == "Unknown"
assert row["mrca"] == 48
assert row["parent_mrca"] == 48
assert row["group_size"] == 2
assert row["diffs"] == 6
assert row["max_run_length"] == 2
Expand All @@ -243,8 +245,6 @@ def test_example_node(self, fx_ts_min_2020_02_15, fx_ti_2020_02_15):
nt.assert_array_equal(
ti.nodes_max_descendant_samples, df["max_descendant_samples"]
)
print(ti.nodes_date.dtype)
print(df["date"].dtype)
nt.assert_array_equal(ti.nodes_date, df["date"])
assert list(np.where(df["is_recombinant"])[0]) == list(ti.recombinants)
assert list(np.where(df["is_sample"])[0]) == list(ts.samples())
Expand Down