@@ -402,7 +402,7 @@ def get_fixed(self, arr, edge):
402402 return arr * liks
403403
404404 def scale_geometric (self , fraction , value ):
405- return value ** fraction
405+ return value ** fraction
406406
407407
408408class LogLikelihoods (Likelihoods ):
@@ -647,11 +647,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
647647 inside = self .priors .clone_with_new_data ( # store inside matrix values
648648 grid_data = np .nan , fixed_data = self .lik .identity_constant
649649 )
650+ # It is possible that a simple node is non-fixed, in which case we want to
651+ # provide an inside array that reflects the prior distribution
652+ nonfixed_samples = np .intersect1d (inside .nonfixed_node_ids (), self .ts .samples ())
653+ for u in nonfixed_samples :
654+ # this is in the same probability space as the prior, so we should be
655+ # OK just to copy the prior values straight in. It's unclear to me (Yan)
656+ # how/if they should be normalised, however
657+ inside [u ][:] = self .priors [u ]
658+
650659 if cache_inside :
651660 g_i = np .full (
652661 (self .ts .num_edges , self .lik .grid_size ), self .lik .identity_constant
653662 )
654663 norm = np .full (self .ts .num_nodes , np .nan )
664+ to_visit = np .zeros (self .ts .num_nodes , dtype = bool )
665+ to_visit [inside .nonfixed_node_ids ()] = True
655666 # Iterate through the nodes via groupby on parent node
656667 for parent , edges in tqdm (
657668 self .edges_by_parent_asc (),
@@ -686,16 +697,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
686697 "dangling nodes: please simplify it"
687698 )
688699 daughter_val = self .lik .scale_geometric (
689- spanfrac , self .lik .make_lower_tri (inside [ edge . child ] )
700+ spanfrac , self .lik .make_lower_tri (inside_values )
690701 )
691702 edge_lik = self .lik .get_inside (daughter_val , edge )
692703 val = self .lik .combine (val , edge_lik )
693704 if np .all (val == 0 ):
694705 raise ValueError
695706 if cache_inside :
696707 g_i [edge .id ] = edge_lik
697- norm [parent ] = np .max (val ) if normalize else 1
708+ norm [parent ] = np .max (val ) if normalize else self . lik . identity_constant
698709 inside [parent ] = self .lik .reduce (val , norm [parent ])
710+ to_visit [parent ] = False
711+
712+ # There may be nodes that are not parents but are also not fixed (e.g.
713+ # undated sample nodes). These need an identity normalization constant
714+ for unfixed_unvisited in np .where (to_visit )[0 ]:
715+ norm [unfixed_unvisited ] = self .lik .identity_constant
699716
700717 if cache_inside :
701718 self .g_i = self .lik .reduce (g_i , norm [self .ts .tables .edges .child , None ])
@@ -913,34 +930,32 @@ def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
913930 return ts , mn_post , vr_post
914931
915932
916- def constrain_ages_topo (ts , post_mn , eps , nodes_to_date = None , progress = False ):
933+ def constrain_ages_topo (ts , node_times , eps , progress = False ):
917934 """
918- If predicted node times violate topology, restrict node ages so that they
919- must be older than all their children.
935+ If node_times violate topology, return increased node_times so that each node is
936+ guaranteed to be older than any of its their children.
920937 """
921- new_mn_post = np .copy (post_mn )
922- if nodes_to_date is None :
923- nodes_to_date = np .arange (ts .num_nodes , dtype = np .uint64 )
924- nodes_to_date = nodes_to_date [~ np .isin (nodes_to_date , ts .samples ())]
925-
926- tables = ts .tables
927- parents = tables .edges .parent
928- nd_children = tables .edges .child [np .argsort (parents )]
929- parents = sorted (parents )
930- parents_unique = np .unique (parents , return_index = True )
931- parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], nodes_to_date )]
932- for index , nd in tqdm (
933- enumerate (sorted (nodes_to_date )), desc = "Constrain Ages" , disable = not progress
938+ edges_parent = ts .edges_parent
939+ edges_child = ts .edges_child
940+
941+ new_node_times = np .copy (node_times )
942+ # Traverse through the ARG, ensuring children come before parents.
943+ # This can be done by iterating over groups of edges with the same parent
944+ new_parent_edge_idx = np .where (np .diff (edges_parent ) != 0 )[0 ] + 1
945+ for edges_start , edges_end in tqdm (
946+ zip (
947+ itertools .chain ([0 ], new_parent_edge_idx ),
948+ itertools .chain (new_parent_edge_idx , [len (edges_parent )]),
949+ ),
950+ desc = "Constrain Ages" ,
951+ disable = not progress ,
934952 ):
935- if index + 1 != len (nodes_to_date ):
936- children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
937- else :
938- children_index = np .arange (parent_indices [index ], ts .num_edges )
939- children = nd_children [children_index ]
940- time = np .max (new_mn_post [children ])
941- if new_mn_post [nd ] <= time :
942- new_mn_post [nd ] = time + eps
943- return new_mn_post
953+ parent = edges_parent [edges_start ]
954+ child_ids = edges_child [edges_start :edges_end ] # May contain dups
955+ oldest_child_time = np .max (new_node_times [child_ids ])
956+ if oldest_child_time >= new_node_times [parent ]:
957+ new_node_times [parent ] = oldest_child_time + eps
958+ return new_node_times
944959
945960
946961def date (
@@ -1031,7 +1046,7 @@ def date(
10311046 progress = progress ,
10321047 ** kwargs
10331048 )
1034- constrained = constrain_ages_topo (tree_sequence , dates , eps , nds , progress )
1049+ constrained = constrain_ages_topo (tree_sequence , dates , eps , progress )
10351050 tables = tree_sequence .dump_tables ()
10361051 tables .time_units = time_units
10371052 tables .nodes .time = constrained
@@ -1080,8 +1095,6 @@ def get_dates(
10801095
10811096 :return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
10821097 """
1083- fixed_nodes = set (tree_sequence .samples ())
1084-
10851098 # Default to not creating approximate priors unless ts has > 1000 samples
10861099 approx_priors = False
10871100 if tree_sequence .num_samples > 1000 :
@@ -1109,6 +1122,8 @@ def get_dates(
11091122 )
11101123 priors = priors
11111124
1125+ fixed_nodes = set (priors .fixed_node_ids ())
1126+
11121127 if probability_space != base .LOG :
11131128 liklhd = Likelihoods (
11141129 tree_sequence ,
0 commit comments