@@ -240,7 +240,6 @@ def node_mutations(self):
240240 muts [site .position ] = f"{ state0 } >{ state1 } "
241241 return muts
242242
243-
244243 def __init__ (
245244 self ,
246245 ts ,
@@ -258,10 +257,13 @@ def __init__(
258257 self .node = node
259258 if edges is None : # the required edge table wasn't given, so recalculate
260259 edges = tskit .EdgeTable ()
261- for e in sorted ([ts .edge (i ) for i in np .where (ts .edges_child == node )[0 ]], key = lambda e : e .left ):
260+ for e in sorted (
261+ [ts .edge (i ) for i in np .where (ts .edges_child == node )[0 ]],
262+ key = lambda e : e .left ,
263+ ):
262264 edges .append (e )
263265 self .edges = edges
264-
266+
265267 def html (
266268 self ,
267269 show_bases = True ,
@@ -278,7 +280,7 @@ def html(
278280 using the ``IPython.display.HTML`` function.
279281
280282 :param ts TreeSequence:
281- The tree sequence to which the nodes refer
283+ The tree sequence to which the nodes refer
282284 :param node int:
283285 The node ID of the child node, usually a recombination node.
284286 This will be placed on the second row of the copying pattern, so that
@@ -311,6 +313,7 @@ def html(
311313 document (e.g. a Jupyter notebook) that already has one copying table shown with
312314 the standard stylesheet. If False or None (default), include the default stylesheet.
313315 """
316+
314317 def row_lab (txt ):
315318 return "" if hide_labels else f"<th>{ txt } </th>"
316319
@@ -448,11 +451,12 @@ def __init__(
448451 quick = False ,
449452 show_progress = True ,
450453 pango_source = "Viridian_pangolin" ,
454+ scorpio_source = "Viridian_scorpio" ,
451455 sample_group_id_prefix_len = 10 ,
452456 ):
453457 self .ts = ts
454458 self .pango_source = pango_source
455- self .scorpio_source = "Viridian_scorpio"
459+ self .scorpio_source = scorpio_source
456460 self .strain_map = {}
457461 self .recombinants = np .where (ts .nodes_flags == core .NODE_IS_RECOMBINANT )[0 ]
458462
@@ -967,11 +971,25 @@ def recombinants_summary(
967971 ):
968972 if parent_pango_source is None :
969973 parent_pango_source = self .pango_source
974+
975+ def node_info (node , label ):
976+ datum = {label : node }
977+ datum [f"{ label } _pango" ] = self .nodes_metadata [node ].get (
978+ self .pango_source , "Unknown"
979+ )
980+ datum [f"{ label } _scorpio" ] = self .nodes_metadata [node ].get (
981+ self .scorpio_source , "Unknown"
982+ )
983+ datum [f"{ label } _time" ] = self .ts .nodes_time [node ]
984+ datum [f"{ label } _date" ] = self .nodes_date [node ]
985+ return datum
986+
970987 data = []
971988 for u in self .recombinants :
972989 md = dict (self .nodes_metadata [u ]["sc2ts" ])
973990 group_id = md ["group_id" ][: self .sample_group_id_prefix_len ]
974991 md ["group_id" ] = group_id
992+
975993 group_nodes = self .sample_group_nodes [group_id ]
976994 md ["group_size" ] = len (group_nodes )
977995
@@ -983,13 +1001,17 @@ def recombinants_summary(
9831001 causal_lineages = {}
9841002 hmm_matches = []
9851003 breakpoint_intervals = []
1004+ copying_path_mutations = collections .defaultdict (list )
9861005 for v in samples :
987- causal_lineages [v ] = self .nodes_metadata [v ].get (
988- self .pango_source , "Unknown"
989- )
990-
991- # Arbitrarily pick the first sample node as the representative
992- v = samples [0 ]
1006+ sample_md = self .nodes_metadata [v ]
1007+ causal_lineages [v ] = sample_md .get (self .pango_source , "Unknown" )
1008+ hmm_mutations = len (sample_md ["sc2ts" ]["hmm_match" ]["mutations" ])
1009+ copying_path_mutations [hmm_mutations ].append (v )
1010+
1011+ min_mutations = min (copying_path_mutations .keys ())
1012+ # Choose our representative sample as one of the ones that have the
1013+ # fewest mutations in its copying path.
1014+ v = copying_path_mutations [min_mutations ][0 ]
9931015 node_md = self .nodes_metadata [v ]["sc2ts" ]
9941016 hmm_matches .append (node_md ["hmm_match" ])
9951017 breakpoint_intervals .append (node_md ["breakpoint_intervals" ])
@@ -1003,30 +1025,33 @@ def recombinants_summary(
10031025 interval = breakpoint_intervals [0 ]
10041026 parent_left = hmm_match ["path" ][0 ]["parent" ]
10051027 parent_right = hmm_match ["path" ][1 ]["parent" ]
1006- data .append (
1007- {
1008- "recombinant" : u ,
1009- "descendants" : self .nodes_max_descendant_samples [u ],
1010- "sample" : v ,
1011- "sample_pango" : causal_lineages [v ],
1012- "num_samples" : len (samples ),
1013- "distinct_sample_pango" : len (set (causal_lineages .values ())),
1014- "interval_left" : interval [0 ][0 ],
1015- "interval_right" : interval [0 ][1 ],
1016- "parent_left" : parent_left ,
1017- "parent_right" : parent_right ,
1018- "parent_left_pango" : self .nodes_metadata [parent_left ].get (
1019- parent_pango_source ,
1020- "Unknown" ,
1021- ),
1022- "parent_right_pango" : self .nodes_metadata [parent_right ].get (
1023- parent_pango_source ,
1024- "Unknown" ,
1025- ),
1026- "num_mutations" : len (hmm_match ["mutations" ]),
1027- ** md ,
1028- }
1029- )
1028+
1029+ datum = {
1030+ "num_descendant_samples" : self .nodes_max_descendant_samples [u ],
1031+ "num_samples" : len (samples ),
1032+ "distinct_sample_pango" : len (set (causal_lineages .values ())),
1033+ "interval_left" : interval [0 ][0 ],
1034+ "interval_right" : interval [0 ][1 ],
1035+ "num_mutations" : len (hmm_match ["mutations" ]),
1036+ "Viridian_amplicon_scheme" : self .nodes_metadata [v ].get (
1037+ "Viridian_amplicon_scheme" , "Unknown"
1038+ ),
1039+ "Artic_primer_version" : self .nodes_metadata [v ].get (
1040+ "Artic_primer_version" , "Unknown"
1041+ ),
1042+ ** md ,
1043+ }
1044+
1045+ for node , label in [
1046+ (u , "recombinant" ),
1047+ (v , "sample" ),
1048+ (parent_left , "parent_left" ),
1049+ (parent_right , "parent_right" ),
1050+ ]:
1051+ datum = {** datum , ** node_info (node , label )}
1052+
1053+ data .append (datum )
1054+
10301055 # Compute the MRCAs by iterating along trees in order of
10311056 # breakpoint. We use the right interval
10321057 df = pd .DataFrame (data ).sort_values ("interval_right" )
@@ -1043,10 +1068,11 @@ def recombinants_summary(
10431068 left_path = jit .get_root_path (tree , row .parent_left )
10441069 assert tree .parent (row .recombinant ) == row .parent_left
10451070 mrca = jit .get_path_mrca (left_path , right_path , self .ts .nodes_time )
1046- mrca_data .append (mrca )
1047- mrca_data = np .array (mrca_data )
1048- df ["mrca" ] = mrca_data
1049- df ["t_mrca" ] = self .ts .nodes_time [mrca_data ]
1071+ mrca_data .append (node_info (mrca , "parent_mrca" ))
1072+
1073+ mrca_df = pd .DataFrame (mrca_data )
1074+ for col in mrca_df :
1075+ df [col ] = mrca_df [col ]
10501076
10511077 if characterise_copying :
10521078 # Slow - don't do this unless we really want to.
0 commit comments