Skip to content

Commit 377aae6

Browse files
committed
Move scalar-PopulationSizeHistory conversion into fill_grid
1 parent a40d1a8 commit 377aae6

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

tsdate/prior.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,12 @@ def fill_priors(
958958
and fill out a NodeGridValues object with the prior values from the
959959
gamma or lognormal distribution with those parameters.
960960
961+
The `population_size` can be a scalar, or an object with a `.to_natural_timescale`
962+
method used to map from coalescent to generational timescale.
963+
961964
TODO - what if there is an internal fixed node? Should we truncate
965+
966+
TODO - support times scaled by generation length?
962967
"""
963968
if prior_distr == "lognorm":
964969
cdf_func = scipy.stats.lognorm.cdf
@@ -975,7 +980,10 @@ def fill_priors(
975980
datable_nodes[ts.samples()] = False
976981
datable_nodes = np.where(datable_nodes)[0]
977982

978-
# convert coalescent time grid to generations
983+
if isinstance(population_size, (int, float, np.ndarray)):
984+
population_size = demography.PopulationSizeHistory(population_size)
985+
986+
# convert coalescent time grid to generational time scale
979987
prior_times = base.NodeGridValues(
980988
ts.num_nodes,
981989
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32),
@@ -1055,9 +1063,6 @@ def make_discretized_prior(self, population_size, timepoints=20, progress=False)
10551063
Calculate prior grid for a set of timepoints and a population size history
10561064
"""
10571065

1058-
if isinstance(population_size, (int, float, np.ndarray)):
1059-
population_size = demography.PopulationSizeHistory(population_size)
1060-
10611066
if isinstance(timepoints, int):
10621067
if timepoints < 2:
10631068
raise ValueError("You must have at least 2 time points")

0 commit comments

Comments
 (0)