@@ -814,7 +814,7 @@ def update_top_level_metadata(ts, date, retro_groups, samples):
814814 return tables
815815
816816
817- def add_sample_to_tables (sample , tables , group_id = None ):
817+ def add_sample_to_tables (sample , tables , group_id = None , time = 0 ):
818818 sc2ts_md = {
819819 "hmm_match" : sample .hmm_match .asdict (),
820820 "alignment_composition" : dict (sample .alignment_composition ),
@@ -825,7 +825,7 @@ def add_sample_to_tables(sample, tables, group_id=None):
825825 if group_id is not None :
826826 sc2ts_md ["group_id" ] = group_id
827827 metadata = {** sample .metadata , "sc2ts" : sc2ts_md }
828- return tables .nodes .add_row (flags = sample .flags , metadata = metadata )
828+ return tables .nodes .add_row (flags = sample .flags , metadata = metadata , time = time )
829829
830830
831831def match_path_ts (group ):
@@ -2031,3 +2031,133 @@ def map_deletions(ts, ds, *, frequency_threshold, show_progress=False):
20312031 tables .build_index ()
20322032 tables .compute_mutation_parents ()
20332033 return tables .tree_sequence ()
2034+
2035+
def append_exact_matches(ts, match_db, show_progress=False):
    """
    Return a copy of the specified tree sequence extended with all exact
    matches from the specified match DB.

    Every row of the ``samples`` table with ``hmm_cost == 0`` and a
    ``match_date`` on or before the date stored in the tree sequence's
    top-level ``sc2ts`` metadata is unpickled and added as a new node:

    - the node is flagged with ``core.NODE_IS_EXACT_MATCH``;
    - its time is the number of days between the sample's date and the
      tree sequence's date;
    - it is joined by a single edge spanning the whole sequence to the
      parent of the first segment of the sample's HMM match path;
    - the sample's strain is appended to ``samples_strain`` in the
      top-level metadata.

    ``match_db`` must expose a ``conn`` attribute that can be used as a
    context manager and has an ``execute`` method (presumably a sqlite3
    connection configured to return dict-like rows -- confirm against the
    match DB class). ``show_progress`` enables a tqdm progress bar.

    The number of nodes added is asserted to equal the total of the
    cumulative exact-match counts recorded in the metadata.
    """
    md = ts.metadata
    date = md["sc2ts"]["date"]
    total_exact_matches = sum(
        md["sc2ts"]["cumulative_stats"]["exact_matches"]["pango"].values()
    )
    samples_strain = md["sc2ts"]["samples_strain"]
    tables = ts.dump_tables()
    L = tables.sequence_length
    time_zero = parse_date(date)
    # Bind the date as a query parameter rather than interpolating it into
    # the SQL text, so that quoting is handled by the database driver.
    sql = "SELECT * FROM samples WHERE hmm_cost == 0 AND match_date <= ?"
    with match_db.conn:
        rows = tqdm.tqdm(
            match_db.conn.execute(sql, (date,)),
            total=total_exact_matches,
            desc="Exact matches",
            disable=not show_progress,
        )
        for row in rows:
            pkl = row.pop("pickle")
            # NOTE: pickle.loads can execute arbitrary code. The match DB
            # must be a trusted file produced by this pipeline.
            sample = pickle.loads(bz2.decompress(pkl))
            sample.flags |= core.NODE_IS_EXACT_MATCH
            delta = time_zero - parse_date(sample.date)
            # Samples were selected with match_date <= date, so they
            # cannot be newer than the tree sequence itself.
            assert delta.days >= 0
            u = add_sample_to_tables(sample, tables, time=delta.days)
            # An exact match copies from a single parent, so the first
            # path segment's parent is attached over the full sequence.
            parent = sample.hmm_match.path[0].parent
            tables.edges.add_row(0, L, parent=parent, child=u)
            samples_strain.append(sample.strain)

    # Cross-check the DB contents against the stats stored in the metadata.
    assert total_exact_matches == len(tables.nodes) - ts.num_nodes
    md["sc2ts"]["samples_strain"] = samples_strain
    tables.metadata = md
    # New edges were appended at the end, so restore the required ordering.
    tables.sort()
    return tables.tree_sequence()
2074+
2075+
def trim_metadata(ts, show_progress=False):
    """
    Return a copy of the specified tree sequence in which the metadata of
    every sample node is reduced to the "strain", "date" and
    "Viridian_pangolin" fields. Metadata of non-sample nodes, and all
    other node properties, are written back unchanged.

    If ``show_progress`` is True a tqdm progress bar is displayed.
    """
    kept_sample_fields = ("strain", "date", "Viridian_pangolin")

    tables = ts.dump_tables()
    # Rebuild the node table from scratch, one row per original node.
    tables.nodes.clear()

    progress = tqdm.tqdm(
        ts.nodes(),
        total=ts.num_nodes,
        desc="Trim node metadata",
        disable=not show_progress,
    )
    for node in progress:
        metadata = node.metadata
        if node.is_sample():
            # Note it would be nice to trim down the name of the pango field
            # here but it's too tedious to test.
            metadata = {field: metadata[field] for field in kept_sample_fields}
        tables.nodes.append(node.replace(metadata=metadata))
    return tables.tree_sequence()
2095+
2096+
def find_reversions(ts):
    """
    Return a boolean array, with one entry per mutation, that is True
    where the mutation's derived state equals the inherited state of its
    parent mutation (i.e. the mutation undoes its parent). Mutations
    without a parent are never marked as reversions.

    All derived and ancestral states must be exactly one byte long; this
    is checked with assertions.
    """
    num_mutations = ts.num_mutations
    tables = ts.tables
    # Offsets of 0, 1, 2, ... mean every state is a single byte, which is
    # what allows the flat state buffers to be viewed as per-row characters.
    assert np.all(
        tables.mutations.derived_state_offset == np.arange(num_mutations + 1)
    )
    derived_state = tables.mutations.derived_state.view("S1").astype(str)
    assert np.all(tables.sites.ancestral_state_offset == np.arange(ts.num_sites + 1))
    ancestral_state = tables.sites.ancestral_state.view("S1").astype(str)
    del tables

    # Start from the site's ancestral state, then override with the parent
    # mutation's derived state wherever a parent exists.
    inherited_state = ancestral_state[ts.mutations_site]
    has_parent = ts.mutations_parent != -1
    parent_ids = ts.mutations_parent[has_parent]
    assert np.all(parent_ids >= 0)
    inherited_state[has_parent] = derived_state[parent_ids]

    # Every mutation must actually change the state.
    assert np.all(inherited_state != derived_state)

    is_reversion = np.zeros(num_mutations, dtype=bool)
    child_derived = derived_state[has_parent]
    parent_inherited = inherited_state[parent_ids]
    is_reversion[has_parent] = child_derived == parent_inherited
    return is_reversion
2124+
2125+
def push_up_unary_recombinant_mutations(ts):
    """
    Find any mutations that occur on unary children of a recombinant node,
    and push those mutations onto the recombinant node itself. The
    rationale for this is that, due to technical details of tree building,
    we sometimes get a single child of a recombinant node, which can have
    a large number of mutations. It is more parsimonious to assume that the
    mutations occurred on the branch(es) *leading to* the recombinant than
    to have succeeded it.

    Here a recombinant is treated as unary when exactly one edge spanning
    the full sequence has it as parent. Moved mutations are assigned the
    recombinant's node ID and node time. Returns a new tree sequence.
    """
    parent_flags = ts.nodes_flags[ts.edges_parent]
    recomb_edge_ids = np.flatnonzero((parent_flags & core.NODE_IS_RECOMBINANT) > 0)
    logger.info(f"Found {len(recomb_edge_ids)} edges with recombinant parent")

    # Group the full-span edges by their recombinant parent.
    full_span_edges = collections.defaultdict(list)
    for edge_id in recomb_edge_ids:
        edge = ts.edge(edge_id)
        if edge.left == 0 and edge.right == ts.sequence_length:
            full_span_edges[edge.parent].append(edge)

    # We're only interested in full-span edges with a single child.
    child_to_parent = {}
    for edges in full_span_edges.values():
        if len(edges) == 1:
            only_edge = edges[0]
            child_to_parent[only_edge.child] = only_edge.parent
    logger.info(f"Of which {len(child_to_parent)} are unary")

    unary_children = np.array(list(child_to_parent), dtype=np.int32)
    mutations_to_move = np.isin(ts.mutations_node, unary_children)
    tables = ts.dump_tables()
    for mutation_id in np.flatnonzero(mutations_to_move):
        row = tables.mutations[mutation_id]
        new_node = child_to_parent[row.node]
        # We're only changing the node and time, which are fixed size so we
        # don't rewrite the table for each of these.
        tables.mutations[mutation_id] = row.replace(
            node=new_node, time=ts.nodes_time[new_node]
        )
    logger.info(f"Moved up {np.sum(mutations_to_move)} mutations")
    tables.sort()
    return tables.tree_sequence()
0 commit comments