@@ -1032,14 +1032,10 @@ def shape_scale_from_mean_var(mean, var):
1032
1032
1033
1033
def _truncate_priors (ts , priors , progress = False ):
1034
1034
"""
1035
- Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
1036
- if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
1037
- sequence
1035
+ Truncate priors for all nonfixed nodes
1036
+ so they conform to the age of fixed nodes in the tree sequence
1038
1037
"""
1039
1038
tables = ts .tables
1040
- truncate_nodes = priors .nonfixed_node_ids ()
1041
- # ensure truncate_nodes is ordered by node time
1042
- truncate_nodes = truncate_nodes [np .argsort (tables .nodes .time [truncate_nodes ])]
1043
1039
1044
1040
fixed_nodes = priors .fixed_node_ids ()
1045
1041
fixed_times = tables .nodes .time [fixed_nodes ]
@@ -1055,29 +1051,32 @@ def _truncate_priors(ts, priors, progress=False):
1055
1051
constrained_min_times = np .zeros_like (tables .nodes .time )
1056
1052
# Set the min times of fixed nodes to those in the tree sequence
1057
1053
constrained_min_times [fixed_nodes ] = fixed_times
1058
- constrained_max_times = np .full_like (constrained_min_times , np .inf )
1059
-
1060
- parents = tables .edges .parent
1061
- nd_children = tables .edges .child [np .argsort (parents )]
1062
- parents = sorted (parents )
1063
- parents_unique = np .unique (parents , return_index = True )
1064
- parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], truncate_nodes )]
1065
- for index , nd in tqdm (
1066
- enumerate (truncate_nodes ), desc = "Constrain Ages" , disable = not progress
1054
+
1055
+ # Traverse through the ARG, ensuring children come before parents.
1056
+ # This can be done by iterating over groups of edges with the same parent
1057
+ new_parent_edge_idx = np .concatenate (
1058
+ (
1059
+ [0 ],
1060
+ np .where (np .diff (tables .edges .parent ) != 0 )[0 ] + 1 ,
1061
+ [tables .edges .num_rows ],
1062
+ )
1063
+ )
1064
+ for edges_start , edges_end in zip (
1065
+ new_parent_edge_idx [:- 1 ], new_parent_edge_idx [1 :]
1067
1066
):
1068
- if index + 1 != len ( truncate_nodes ):
1069
- children_index = np . arange ( parent_indices [ index ], parent_indices [ index + 1 ])
1070
- else :
1071
- children_index = np . arange ( parent_indices [ index ], ts . num_edges )
1072
- children = nd_children [ children_index ]
1073
- time = np . max ( constrained_min_times [ children ])
1074
- # The constrained time of the node should be the age of the oldest child
1075
- if constrained_min_times [ nd ] <= time :
1076
- constrained_min_times [ nd ] = time
1077
- nearest_time = np . argmin ( np . abs ( timepoints - time ))
1078
- lookup_index = priors .row_lookup [ int ( nd )]
1079
- grid_data [ lookup_index ][: nearest_time ] = zero_value
1080
- assert np . all ( constrained_min_times < constrained_max_times )
1067
+ parent = tables . edges . parent [ edges_start ]
1068
+ child_ids = tables . edges . child [ edges_start : edges_end ] # May contain dups
1069
+ oldest_child_time = np . max ( constrained_min_times [ child_ids ])
1070
+ if oldest_child_time > constrained_min_times [ parent ]:
1071
+ if priors . is_fixed ( parent ):
1072
+ raise ValueError (
1073
+ "Invalid fixed times: time for"
1074
+ + f"fixed node { parent } is younger than some of its descendants"
1075
+ )
1076
+ constrained_min_times [ parent ] = oldest_child_time
1077
+ if constrained_min_times [ parent ] > 0 and not priors .is_fixed ( parent ):
1078
+ nearest_time = np . argmin ( np . abs ( timepoints - constrained_min_times [ parent ]))
1079
+ grid_data [ priors . row_lookup [ parent ]][: nearest_time ] = zero_value
1081
1080
1082
1081
rowmax = grid_data [:, 1 :].max (axis = 1 )
1083
1082
if priors .probability_space == "linear" :
@@ -1138,7 +1137,7 @@ def build_grid(
1138
1137
:param dict node_var_override: is a dict mapping node IDs to a variance value.
1139
1138
Any nodes listed here will be treated as non-fixed nodes whose prior is not
1140
1139
calculated from the conditional coalescent but instead are allocated a prior
1141
- whose mean is thenode time in the tree sequence and whose variance is the
1140
+ whose mean is the node time in the tree sequence and whose variance is the
1142
1141
value in this dictionary. This allows sample nodes to be treated as nonfixed
1143
1142
nodes, and therefore dated. If ``None`` (default) then all sample nodes are
1144
1143
treated as occurring ata fixed time (as if this were an empty dict).
0 commit comments