@@ -402,7 +402,7 @@ def get_fixed(self, arr, edge):
402
402
return arr * liks
403
403
404
404
def scale_geometric (self , fraction , value ):
405
- return value ** fraction
405
+ return value ** fraction
406
406
407
407
408
408
class LogLikelihoods (Likelihoods ):
@@ -647,11 +647,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
647
647
inside = self .priors .clone_with_new_data ( # store inside matrix values
648
648
grid_data = np .nan , fixed_data = self .lik .identity_constant
649
649
)
650
+ # It is possible that a simple node is non-fixed, in which case we want to
651
+ # provide an inside array that reflects the prior distribution
652
+ nonfixed_samples = np .intersect1d (inside .nonfixed_node_ids (), self .ts .samples ())
653
+ for u in nonfixed_samples :
654
+ # this is in the same probability space as the prior, so we should be
655
+ # OK just to copy the prior values straight in. It's unclear to me (Yan)
656
+ # how/if they should be normalised, however
657
+ inside [u ][:] = self .priors [u ]
658
+
650
659
if cache_inside :
651
660
g_i = np .full (
652
661
(self .ts .num_edges , self .lik .grid_size ), self .lik .identity_constant
653
662
)
654
663
norm = np .full (self .ts .num_nodes , np .nan )
664
+ to_visit = np .zeros (self .ts .num_nodes , dtype = bool )
665
+ to_visit [inside .nonfixed_node_ids ()] = True
655
666
# Iterate through the nodes via groupby on parent node
656
667
for parent , edges in tqdm (
657
668
self .edges_by_parent_asc (),
@@ -686,16 +697,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
686
697
"dangling nodes: please simplify it"
687
698
)
688
699
daughter_val = self .lik .scale_geometric (
689
- spanfrac , self .lik .make_lower_tri (inside [ edge . child ] )
700
+ spanfrac , self .lik .make_lower_tri (inside_values )
690
701
)
691
702
edge_lik = self .lik .get_inside (daughter_val , edge )
692
703
val = self .lik .combine (val , edge_lik )
693
704
if np .all (val == 0 ):
694
705
raise ValueError
695
706
if cache_inside :
696
707
g_i [edge .id ] = edge_lik
697
- norm [parent ] = np .max (val ) if normalize else 1
708
+ norm [parent ] = np .max (val ) if normalize else self . lik . identity_constant
698
709
inside [parent ] = self .lik .reduce (val , norm [parent ])
710
+ to_visit [parent ] = False
711
+
712
+ # There may be nodes that are not parents but are also not fixed (e.g.
713
+ # undated sample nodes). These need an identity normalization constant
714
+ for unfixed_unvisited in np .where (to_visit )[0 ]:
715
+ norm [unfixed_unvisited ] = self .lik .identity_constant
699
716
700
717
if cache_inside :
701
718
self .g_i = self .lik .reduce (g_i , norm [self .ts .tables .edges .child , None ])
@@ -913,34 +930,32 @@ def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
913
930
return ts , mn_post , vr_post
914
931
915
932
916
- def constrain_ages_topo (ts , post_mn , eps , nodes_to_date = None , progress = False ):
933
+ def constrain_ages_topo (ts , node_times , eps , progress = False ):
917
934
"""
918
- If predicted node times violate topology, restrict node ages so that they
919
- must be older than all their children.
935
+ If node_times violate topology, return increased node_times so that each node is
936
+ guaranteed to be older than any of its their children.
920
937
"""
921
- new_mn_post = np .copy (post_mn )
922
- if nodes_to_date is None :
923
- nodes_to_date = np .arange (ts .num_nodes , dtype = np .uint64 )
924
- nodes_to_date = nodes_to_date [~ np .isin (nodes_to_date , ts .samples ())]
925
-
926
- tables = ts .tables
927
- parents = tables .edges .parent
928
- nd_children = tables .edges .child [np .argsort (parents )]
929
- parents = sorted (parents )
930
- parents_unique = np .unique (parents , return_index = True )
931
- parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], nodes_to_date )]
932
- for index , nd in tqdm (
933
- enumerate (sorted (nodes_to_date )), desc = "Constrain Ages" , disable = not progress
938
+ edges_parent = ts .edges_parent
939
+ edges_child = ts .edges_child
940
+
941
+ new_node_times = np .copy (node_times )
942
+ # Traverse through the ARG, ensuring children come before parents.
943
+ # This can be done by iterating over groups of edges with the same parent
944
+ new_parent_edge_idx = np .where (np .diff (edges_parent ) != 0 )[0 ] + 1
945
+ for edges_start , edges_end in tqdm (
946
+ zip (
947
+ itertools .chain ([0 ], new_parent_edge_idx ),
948
+ itertools .chain (new_parent_edge_idx , [len (edges_parent )]),
949
+ ),
950
+ desc = "Constrain Ages" ,
951
+ disable = not progress ,
934
952
):
935
- if index + 1 != len (nodes_to_date ):
936
- children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
937
- else :
938
- children_index = np .arange (parent_indices [index ], ts .num_edges )
939
- children = nd_children [children_index ]
940
- time = np .max (new_mn_post [children ])
941
- if new_mn_post [nd ] <= time :
942
- new_mn_post [nd ] = time + eps
943
- return new_mn_post
953
+ parent = edges_parent [edges_start ]
954
+ child_ids = edges_child [edges_start :edges_end ] # May contain dups
955
+ oldest_child_time = np .max (new_node_times [child_ids ])
956
+ if oldest_child_time >= new_node_times [parent ]:
957
+ new_node_times [parent ] = oldest_child_time + eps
958
+ return new_node_times
944
959
945
960
946
961
def date (
@@ -1031,7 +1046,7 @@ def date(
1031
1046
progress = progress ,
1032
1047
** kwargs
1033
1048
)
1034
- constrained = constrain_ages_topo (tree_sequence , dates , eps , nds , progress )
1049
+ constrained = constrain_ages_topo (tree_sequence , dates , eps , progress )
1035
1050
tables = tree_sequence .dump_tables ()
1036
1051
tables .time_units = time_units
1037
1052
tables .nodes .time = constrained
@@ -1080,8 +1095,6 @@ def get_dates(
1080
1095
1081
1096
:return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
1082
1097
"""
1083
- fixed_nodes = set (tree_sequence .samples ())
1084
-
1085
1098
# Default to not creating approximate priors unless ts has > 1000 samples
1086
1099
approx_priors = False
1087
1100
if tree_sequence .num_samples > 1000 :
@@ -1109,6 +1122,8 @@ def get_dates(
1109
1122
)
1110
1123
priors = priors
1111
1124
1125
+ fixed_nodes = set (priors .fixed_node_ids ())
1126
+
1112
1127
if probability_space != base .LOG :
1113
1128
liklhd = Likelihoods (
1114
1129
tree_sequence ,
0 commit comments