Skip to content

Commit 72cb883

Browse files
authored
Merge pull request #252 from nspope/ne-rescale-time-fixup
Fixup for #248, use generational scale for custom time grid
2 parents 77d4068 + 73d08ba commit 72cb883

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

tests/test_functions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from tsdate.prior import ConditionalCoalescentTimes
5050
from tsdate.prior import fill_priors
5151
from tsdate.prior import gamma_approx
52+
from tsdate.prior import MixturePrior
5253
from tsdate.prior import PriorParams
5354
from tsdate.prior import SpansBySamples
5455
from tsdate.util import nodes_time_unconstrained
@@ -508,6 +509,14 @@ def test_two_tree_mutation_ts_intervals(self):
508509
tests = self.check_intervals(ts, delete_interval_ts, keep_interval_ts)
509510
assert np.all(tests)
510511

512+
def test_custom_timegrid_is_not_rescaled(self):
513+
ts = utility_functions.two_tree_mutation_ts()
514+
prior = MixturePrior(ts)
515+
demography = PopulationSizeHistory(3)
516+
timepoints = np.array([0, 300, 1000, 2000])
517+
prior_grid = prior.make_discretized_prior(demography, timepoints=timepoints)
518+
assert np.array_equal(prior_grid.timepoints, timepoints)
519+
511520

512521
class TestPriorVals:
513522
def verify_prior_vals(self, ts, prior_distr, **kwargs):

tsdate/prior.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -980,10 +980,7 @@ def fill_priors(
980980
datable_nodes[ts.samples()] = False
981981
datable_nodes = np.where(datable_nodes)[0]
982982

983-
if isinstance(population_size, (int, float, np.ndarray)):
984-
population_size = demography.PopulationSizeHistory(population_size)
985-
986-
# convert coalescent time grid to generational time scale
983+
# convert timepoints to generational timescale
987984
prior_times = base.NodeGridValues(
988985
ts.num_nodes,
989986
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32),
@@ -994,6 +991,7 @@ def fill_priors(
994991
for node in tqdm(
995992
datable_nodes, desc="Assign Prior to Each Node", disable=not progress
996993
):
994+
# NB: prior CDF is evaluated on coalescent timescale
997995
with np.errstate(divide="ignore", invalid="ignore"):
998996
prior_node = cdf_func(timepoints, main_param[node], scale=scale_param[node])
999997
# force age to be less than max value
@@ -1063,6 +1061,9 @@ def make_discretized_prior(self, population_size, timepoints=20, progress=False)
10631061
Calculate prior grid for a set of timepoints and a population size history
10641062
"""
10651063

1064+
if isinstance(population_size, (int, float, np.ndarray)):
1065+
population_size = demography.PopulationSizeHistory(population_size)
1066+
10661067
if isinstance(timepoints, int):
10671068
if timepoints < 2:
10681069
raise ValueError("You must have at least 2 time points")
@@ -1080,6 +1081,9 @@ def make_discretized_prior(self, population_size, timepoints=20, progress=False)
10801081
raise ValueError("Timepoints cannot be negative")
10811082
elif np.any(np.unique(timepoints, return_counts=True)[1] > 1):
10821083
raise ValueError("Timepoints cannot have duplicate values")
1084+
# timepoints are assumed to be on generational scale, so convert to
1085+
# coalescent timescale to evaluate prior
1086+
timepoints = population_size.to_coalescent_timescale(timepoints)
10831087
else:
10841088
raise ValueError(
10851089
"time_slices must be an integer or a numpy array of floats"

0 commit comments

Comments
 (0)