@@ -996,38 +996,42 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
996
996
return prior_times
997
997
998
998
999
- def _truncate_priors (ts , priors , nodes_to_date = None , progress = False ):
999
+ def _truncate_priors (ts , priors , progress = False ):
1000
1000
"""
1001
- Truncate priors so they conform to the age of nodes in the tree sequence
1001
+ Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
1002
+ if truncate_nodes in None) so they conform to the age of fixed nodes in the tree
1003
+ sequence
1002
1004
"""
1003
1005
tables = ts .tables
1004
- if nodes_to_date is None :
1005
- nodes_to_date = np .arange (ts .num_nodes , dtype = np .uint64 )
1006
- nodes_to_date = nodes_to_date [~ np .isin (nodes_to_date , ts .samples ())]
1007
- # ensure nodes_to_date is ordered by node time
1008
- nodes_to_date = nodes_to_date [np .argsort (tables .nodes .time [nodes_to_date ])]
1006
+ truncate_nodes = priors .nonfixed_node_ids ()
1007
+ # ensure truncate_nodes is ordered by node time
1008
+ truncate_nodes = truncate_nodes [np .argsort (tables .nodes .time [truncate_nodes ])]
1009
+
1010
+ fixed_nodes = priors .fixed_node_ids ()
1011
+ fixed_times = tables .nodes .time [fixed_nodes ]
1009
1012
1010
1013
grid_data = np .copy (priors .grid_data [:])
1011
1014
timepoints = priors .timepoints
1012
- if np .max (tables . nodes . time [ ts . samples ()] ) >= np .max (timepoints ):
1013
- raise ValueError ("Sample times cannot be larger than the oldest timepoint" )
1015
+ if np .max (fixed_times ) >= np .max (timepoints ):
1016
+ raise ValueError ("Fixed node times cannot be older than the oldest timepoint" )
1014
1017
if priors .probability_space == "linear" :
1015
1018
zero_value = 0
1016
1019
elif priors .probability_space == "logarithmic" :
1017
1020
zero_value = - np .inf
1018
1021
constrained_min_times = np .zeros_like (tables .nodes .time )
1019
- constrained_min_times [ts .samples ()] = tables .nodes .time [ts .samples ()]
1022
+ # Set the min times of fixed nodes to those in the tree sequence
1023
+ constrained_min_times [fixed_nodes ] = fixed_times
1020
1024
constrained_max_times = np .full_like (constrained_min_times , np .inf )
1021
1025
1022
1026
parents = tables .edges .parent
1023
1027
nd_children = tables .edges .child [np .argsort (parents )]
1024
1028
parents = sorted (parents )
1025
1029
parents_unique = np .unique (parents , return_index = True )
1026
- parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], nodes_to_date )]
1030
+ parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], truncate_nodes )]
1027
1031
for index , nd in tqdm (
1028
- enumerate (nodes_to_date ), desc = "Constrain Ages" , disable = not progress
1032
+ enumerate (truncate_nodes ), desc = "Constrain Ages" , disable = not progress
1029
1033
):
1030
- if index + 1 != len (nodes_to_date ):
1034
+ if index + 1 != len (truncate_nodes ):
1031
1035
children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
1032
1036
else :
1033
1037
children_index = np .arange (parent_indices [index ], ts .num_edges )
0 commit comments