@@ -991,38 +991,42 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
991
991
return prior_times
992
992
993
993
994
- def _truncate_priors (ts , priors , nodes_to_date = None , progress = False ):
994
+ def _truncate_priors (ts , priors , progress = False ):
995
995
"""
996
- Truncate priors so they conform to the age of nodes in the tree sequence
996
+ Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
997
+ if truncate_nodes in None) so they conform to the age of fixed nodes in the tree
998
+ sequence
997
999
"""
998
1000
tables = ts .tables
999
- if nodes_to_date is None :
1000
- nodes_to_date = np .arange (ts .num_nodes , dtype = np .uint64 )
1001
- nodes_to_date = nodes_to_date [~ np .isin (nodes_to_date , ts .samples ())]
1002
- # ensure nodes_to_date is ordered by node time
1003
- nodes_to_date = nodes_to_date [np .argsort (tables .nodes .time [nodes_to_date ])]
1001
+ truncate_nodes = priors .nonfixed_node_ids ()
1002
+ # ensure truncate_nodes is ordered by node time
1003
+ truncate_nodes = truncate_nodes [np .argsort (tables .nodes .time [truncate_nodes ])]
1004
+
1005
+ fixed_nodes = priors .fixed_node_ids ()
1006
+ fixed_times = tables .nodes .time [fixed_nodes ]
1004
1007
1005
1008
grid_data = np .copy (priors .grid_data [:])
1006
1009
timepoints = priors .timepoints
1007
- if np .max (tables . nodes . time [ ts . samples ()] ) >= np .max (timepoints ):
1008
- raise ValueError ("Sample times cannot be larger than the oldest timepoint" )
1010
+ if np .max (fixed_times ) >= np .max (timepoints ):
1011
+ raise ValueError ("Fixed node times cannot be older than the oldest timepoint" )
1009
1012
if priors .probability_space == "linear" :
1010
1013
zero_value = 0
1011
1014
elif priors .probability_space == "logarithmic" :
1012
1015
zero_value = - np .inf
1013
1016
constrained_min_times = np .zeros_like (tables .nodes .time )
1014
- constrained_min_times [ts .samples ()] = tables .nodes .time [ts .samples ()]
1017
+ # Set the min times of fixed nodes to those in the tree sequence
1018
+ constrained_min_times [fixed_nodes ] = fixed_times
1015
1019
constrained_max_times = np .full_like (constrained_min_times , np .inf )
1016
1020
1017
1021
parents = tables .edges .parent
1018
1022
nd_children = tables .edges .child [np .argsort (parents )]
1019
1023
parents = sorted (parents )
1020
1024
parents_unique = np .unique (parents , return_index = True )
1021
- parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], nodes_to_date )]
1025
+ parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], truncate_nodes )]
1022
1026
for index , nd in tqdm (
1023
- enumerate (nodes_to_date ), desc = "Constrain Ages" , disable = not progress
1027
+ enumerate (truncate_nodes ), desc = "Constrain Ages" , disable = not progress
1024
1028
):
1025
- if index + 1 != len (nodes_to_date ):
1029
+ if index + 1 != len (truncate_nodes ):
1026
1030
children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
1027
1031
else :
1028
1032
children_index = np .arange (parent_indices [index ], ts .num_edges )
0 commit comments