@@ -980,10 +980,7 @@ def fill_priors(
980
980
datable_nodes [ts .samples ()] = False
981
981
datable_nodes = np .where (datable_nodes )[0 ]
982
982
983
- if isinstance (population_size , (int , float , np .ndarray )):
984
- population_size = demography .PopulationSizeHistory (population_size )
985
-
986
- # convert coalescent time grid to generational time scale
983
+ # convert timepoints to generational timescale
987
984
prior_times = base .NodeGridValues (
988
985
ts .num_nodes ,
989
986
datable_nodes [np .argsort (ts .tables .nodes .time [datable_nodes ])].astype (np .int32 ),
@@ -994,6 +991,7 @@ def fill_priors(
994
991
for node in tqdm (
995
992
datable_nodes , desc = "Assign Prior to Each Node" , disable = not progress
996
993
):
994
+ # NB: prior CDF is evaluated on coalescent timescale
997
995
with np .errstate (divide = "ignore" , invalid = "ignore" ):
998
996
prior_node = cdf_func (timepoints , main_param [node ], scale = scale_param [node ])
999
997
# force age to be less than max value
@@ -1063,6 +1061,9 @@ def make_discretized_prior(self, population_size, timepoints=20, progress=False)
1063
1061
Calculate prior grid for a set of timepoints and a population size history
1064
1062
"""
1065
1063
1064
+ if isinstance (population_size , (int , float , np .ndarray )):
1065
+ population_size = demography .PopulationSizeHistory (population_size )
1066
+
1066
1067
if isinstance (timepoints , int ):
1067
1068
if timepoints < 2 :
1068
1069
raise ValueError ("You must have at least 2 time points" )
@@ -1080,6 +1081,9 @@ def make_discretized_prior(self, population_size, timepoints=20, progress=False)
1080
1081
raise ValueError ("Timepoints cannot be negative" )
1081
1082
elif np .any (np .unique (timepoints , return_counts = True )[1 ] > 1 ):
1082
1083
raise ValueError ("Timepoints cannot have duplicate values" )
1084
+ # timepoints are assumed to be on generational scale, so convert to
1085
+ # coalescent timescale to evaluate prior
1086
+ timepoints = population_size .to_coalescent_timescale (timepoints )
1083
1087
else :
1084
1088
raise ValueError (
1085
1089
"time_slices must be an integer or a numpy array of floats"
0 commit comments