@@ -653,7 +653,10 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
653
653
self .priors .nonfixed_node_ids (), self .ts .samples ()
654
654
)
655
655
for u in nonfixed_samples :
656
- inside [u ][:] = self .priors [u ]
656
+ # this is in the same probability space as the prior, so we should be
657
+ # OK just to copy the prior values straight in (but we should check they
658
+ # are normalised so that they sum to unity)
659
+ inside [u ][:] = self .priors .sum_to_unity (self .priors [u ])
657
660
658
661
if cache_inside :
659
662
g_i = np .full (
@@ -921,34 +924,31 @@ def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
921
924
return ts , mn_post , vr_post
922
925
923
926
924
- def constrain_ages_topo (ts , post_mn , eps , nodes_to_date = None , progress = False ):
927
+ def constrain_ages_topo (ts , node_times , eps , progress = False ):
925
928
"""
926
- If predicted node times violate topology, restrict node ages so that they
927
- must be older than all their children.
929
+ If node_times violate topology, return increased node_times so that each node is
930
+ guaranteed to be older than any of its their children.
928
931
"""
929
- new_mn_post = np .copy (post_mn )
930
- if nodes_to_date is None :
931
- nodes_to_date = np .arange (ts .num_nodes , dtype = np .uint64 )
932
- nodes_to_date = nodes_to_date [~ np .isin (nodes_to_date , ts .samples ())]
933
-
934
932
tables = ts .tables
935
- parents = tables .edges .parent
936
- nd_children = tables .edges .child [np .argsort (parents )]
937
- parents = sorted (parents )
938
- parents_unique = np .unique (parents , return_index = True )
939
- parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], nodes_to_date )]
940
- for index , nd in tqdm (
941
- enumerate (sorted (nodes_to_date )), desc = "Constrain Ages" , disable = not progress
933
+ new_node_times = np .copy (node_times )
934
+ # Traverse through the ARG, ensuring children come before parents.
935
+ # This can be done by iterating over groups of edges with the same parent
936
+ new_parent_edge_idx = np .concatenate (
937
+ (
938
+ [0 ],
939
+ np .where (np .diff (tables .edges .parent ) != 0 )[0 ] + 1 ,
940
+ [tables .edges .num_rows ],
941
+ )
942
+ )
943
+ for edges_start , edges_end in zip (
944
+ new_parent_edge_idx [:- 1 ], new_parent_edge_idx [1 :]
942
945
):
943
- if index + 1 != len (nodes_to_date ):
944
- children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
945
- else :
946
- children_index = np .arange (parent_indices [index ], ts .num_edges )
947
- children = nd_children [children_index ]
948
- time = np .max (new_mn_post [children ])
949
- if new_mn_post [nd ] <= time :
950
- new_mn_post [nd ] = time + eps
951
- return new_mn_post
946
+ parent = tables .edges .parent [edges_start ]
947
+ child_ids = tables .edges .child [edges_start :edges_end ] # May contain dups
948
+ oldest_child_time = np .max (new_node_times [child_ids ])
949
+ if oldest_child_time >= new_node_times [parent ]:
950
+ new_node_times [parent ] = oldest_child_time + eps
951
+ return new_node_times
952
952
953
953
954
954
def date (
@@ -1039,7 +1039,7 @@ def date(
1039
1039
progress = progress ,
1040
1040
** kwargs
1041
1041
)
1042
- constrained = constrain_ages_topo (tree_sequence , dates , eps , nds , progress )
1042
+ constrained = constrain_ages_topo (tree_sequence , dates , eps , progress )
1043
1043
tables = tree_sequence .dump_tables ()
1044
1044
tables .time_units = time_units
1045
1045
tables .nodes .time = constrained
0 commit comments