Skip to content

Commit 1a1f132

Browse files
committed
Rescale by population size outside of prior
1 parent f33675d commit 1a1f132

File tree

1 file changed

+37
-10
lines changed

1 file changed

+37
-10
lines changed

tsdate/prior.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ class ConditionalCoalescentTimes:
7272
def __init__(
7373
self,
7474
precalc_approximation_n,
75-
population_size,
7675
prior_distr="lognorm",
7776
progress=False,
7877
):
@@ -83,7 +82,6 @@ def __init__(
8382
and therefore do not allow approximate priors to be used
8483
"""
8584
self.n_approx = precalc_approximation_n
86-
self.population_size = population_size
8785
self.prior_store = {}
8886
self.progress = progress
8987

@@ -181,8 +179,7 @@ def add(self, total_tips, approximate=None):
181179
priors[1] = PriorParams(alpha=0, beta=1, mean=0, var=0)
182180
for var, tips in zip(variances, all_tips):
183181
# NB: it should be possible to vectorize this in numpy
184-
var = var * ((2 * self.Ne) ** 2) # TODO
185-
expectation = self.tau_expect(tips, total_tips) * 2 * self.Ne # TODO
182+
expectation = self.tau_expect(tips, total_tips)
186183
alpha, beta = self.func_approx(expectation, var)
187184
priors[tips] = PriorParams(
188185
alpha=alpha, beta=beta, mean=expectation, var=var
@@ -976,10 +973,15 @@ def fill_priors(
976973
datable_nodes = np.ones(ts.num_nodes, dtype=bool)
977974
datable_nodes[ts.samples()] = False
978975
datable_nodes = np.where(datable_nodes)[0]
976+
977+
rescaled_timepoints = util.rescale_time_by_population_size(
978+
timepoints, population_size
979+
)
980+
979981
prior_times = base.NodeGridValues(
980982
ts.num_nodes,
981983
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32),
982-
timepoints,
984+
rescaled_timepoints,
983985
)
984986

985987
# TO DO - this can probably be done in an single numpy step rather than a for loop
@@ -1020,7 +1022,7 @@ def build_grid(
10201022
:param float population_size: The estimated (diploid) effective population
10211023
size: must be specified. May be a single value, or a two-column array with
10221024
epoch breakpoints and effective population sizes. Using standard (unscaled)
1023-
values for ``population_size`` results in a prior where times are measures
1025+
values for ``population_size`` results in a prior where times are measured
10241026
in generations.
10251027
:param int_or_array_like timepoints: The number of quantiles used to create the
10261028
time slices, or manually-specified time slices as a numpy array. Default: 20
@@ -1041,9 +1043,34 @@ def build_grid(
10411043
inference and a discretised time grid
10421044
:rtype: base.NodeGridValues Object
10431045
"""
1044-
# TODO
1045-
if population_size <= 0:
1046-
raise ValueError("Parameter 'population_size' must be greater than 0")
1046+
if isinstance(population_size, np.ndarray):
1047+
if population_size.ndim != 2:
1048+
raise ValueError(
1049+
"Parameter 'population_size' must be a scalar or a 2d array"
1050+
)
1051+
if population_size.shape[1] != 2:
1052+
raise ValueError(
1053+
"'population_size' array must have two columns that contain \
1054+
epoch start times and population sizes, respectively"
1055+
)
1056+
if np.any(population_size[:, 0] < 0.0):
1057+
raise ValueError(
1058+
"Epoch start times in 'population_size' array must be nonnegative"
1059+
)
1060+
if np.any(population_size[:, 1] <= 0.0):
1061+
raise ValueError(
1062+
"Population sizes in 'population_size' array must be positive "
1063+
)
1064+
if population_size[0, 0] != 0:
1065+
raise ValueError(
1066+
"The first epoch in 'population_size' array must start at time 0"
1067+
)
1068+
if not np.all(np.diff(population_size[:, 0]) > 0):
1069+
raise ValueError(
1070+
"Epoch start times 'population_size' array must be unique and increasing"
1071+
)
1072+
elif population_size <= 0:
1073+
raise ValueError("Scalar 'population_size' must be greater than 0")
10471074
if approximate_priors:
10481075
if not approx_prior_size:
10491076
approx_prior_size = 1000
@@ -1062,7 +1089,7 @@ def build_grid(
10621089
span_data = SpansBySamples(contmpr_ts, progress=progress, allow_unary=allow_unary)
10631090

10641091
base_priors = ConditionalCoalescentTimes(
1065-
approx_prior_size, population_size, prior_distribution, progress=progress
1092+
approx_prior_size, prior_distribution, progress=progress
10661093
)
10671094

10681095
base_priors.add(contmpr_ts.num_samples, approximate_priors)

0 commit comments

Comments
 (0)