@@ -991,6 +991,59 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
991
991
return prior_times
992
992
993
993
994
+ def truncate_priors (ts , sample_times , priors , nodes_to_date = None , progress = False ):
995
+ """
996
+ Truncate priors so they conform to the age of nodes in the tree sequence
997
+ """
998
+ grid_data = np .copy (priors .grid_data [:])
999
+ timepoints = priors .timepoints
1000
+ if np .max (sample_times ) >= np .max (timepoints ):
1001
+ raise ValueError ("Sample times cannot be larger than the oldest timepoint" )
1002
+ if priors .probability_space == "linear" :
1003
+ zero_value = 0
1004
+ one_value = 1
1005
+ elif priors .probability_space == "logarithmic" :
1006
+ zero_value = - np .inf
1007
+ one_value = 0
1008
+ constrained_min_times = np .copy (sample_times )
1009
+ constrained_max_times = np .full (sample_times .shape [0 ], np .inf )
1010
+ if nodes_to_date is None :
1011
+ nodes_to_date = np .arange (ts .num_nodes , dtype = np .uint64 )
1012
+ nodes_to_date = nodes_to_date [~ np .isin (nodes_to_date , ts .samples ())]
1013
+
1014
+ tables = ts .tables
1015
+ parents = tables .edges .parent
1016
+ nd_children = tables .edges .child [np .argsort (parents )]
1017
+ parents = sorted (parents )
1018
+ parents_unique = np .unique (parents , return_index = True )
1019
+ parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], nodes_to_date )]
1020
+ for index , nd in tqdm (
1021
+ enumerate (sorted (nodes_to_date )), desc = "Constrain Ages" , disable = not progress
1022
+ ):
1023
+ if index + 1 != len (nodes_to_date ):
1024
+ children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
1025
+ else :
1026
+ children_index = np .arange (parent_indices [index ], ts .num_edges )
1027
+ children = nd_children [children_index ]
1028
+ time = np .max (constrained_min_times [children ])
1029
+ # The constrained time of the node should be the age of the oldest child
1030
+ if constrained_min_times [nd ] <= time :
1031
+ constrained_min_times [nd ] = time
1032
+ nearest_time = np .argmin (np .abs (timepoints - time ))
1033
+ lookup_index = priors .row_lookup [int (nd )]
1034
+ grid_data [lookup_index ][:nearest_time ] = zero_value
1035
+ assert np .all (constrained_min_times < constrained_max_times )
1036
+ all_zeros = np .where (np .all (grid_data == zero_value , axis = 1 ))[0 ]
1037
+
1038
+ rowmax = grid_data [:, 1 :].max (axis = 1 )
1039
+ if priors .probability_space == "linear" :
1040
+ grid_data = grid_data / rowmax [:, np .newaxis ]
1041
+ elif priors .probability_space == "logarithmic" :
1042
+ grid_data = grid_data - rowmax [:, np .newaxis ]
1043
+
1044
+ priors .grid_data [:] = grid_data
1045
+ return constrained_min_times , constrained_max_times , priors
1046
+
994
1047
def build_grid (
995
1048
tree_sequence ,
996
1049
Ne ,
@@ -1001,6 +1054,7 @@ def build_grid(
1001
1054
prior_distribution = "lognorm" ,
1002
1055
eps = 1e-6 ,
1003
1056
progress = False ,
1057
+ sample_times = None
1004
1058
):
1005
1059
"""
1006
1060
Using the conditional coalescent, calculate the prior distribution for the age of
@@ -1084,4 +1138,7 @@ def build_grid(
1084
1138
prior_distr = prior_distribution ,
1085
1139
progress = progress ,
1086
1140
)
1141
+ if np .any (tree_sequence .tables .nodes .time [tree_sequence .samples ()] != 0 ):
1142
+ if False :
1143
+ priors = truncate_priors (tree_sequence , sample_times , priors , eps , progress = progress )
1087
1144
return priors
0 commit comments