Skip to content

Commit bed072b

Browse files
committed
Utility to rescale time
1 parent 1a1f132 commit bed072b

File tree

2 files changed

+48
-17
lines changed

2 files changed

+48
-17
lines changed

tsdate/prior.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -974,8 +974,8 @@ def fill_priors(
974974
datable_nodes[ts.samples()] = False
975975
datable_nodes = np.where(datable_nodes)[0]
976976

977-
rescaled_timepoints = util.rescale_time_by_population_size(
978-
timepoints, population_size
977+
rescaled_timepoints, _, _ = util.change_time_measure(
978+
timepoints, population_size[:, 0], 1 / (2 * population_size[:, 1])
979979
)
980980

981981
prior_times = base.NodeGridValues(
@@ -1050,27 +1050,21 @@ def build_grid(
10501050
)
10511051
if population_size.shape[1] != 2:
10521052
raise ValueError(
1053-
"'population_size' array must have two columns that contain \
1053+
"Population size array must have two columns that contain \
10541054
epoch start times and population sizes, respectively"
10551055
)
10561056
if np.any(population_size[:, 0] < 0.0):
1057-
raise ValueError(
1058-
"Epoch start times in 'population_size' array must be nonnegative"
1059-
)
1057+
raise ValueError("Epoch start times must be nonnegative")
10601058
if np.any(population_size[:, 1] <= 0.0):
1061-
raise ValueError(
1062-
"Population sizes in 'population_size' array must be positive "
1063-
)
1059+
raise ValueError("Population sizes must be positive ")
10641060
if population_size[0, 0] != 0:
1065-
raise ValueError(
1066-
"The first epoch in 'population_size' array must start at time 0"
1067-
)
1061+
raise ValueError("The first epoch must start at time 0")
10681062
if not np.all(np.diff(population_size[:, 0]) > 0):
1069-
raise ValueError(
1070-
"Epoch start times 'population_size' array must be unique and increasing"
1071-
)
1072-
elif population_size <= 0:
1073-
raise ValueError("Scalar 'population_size' must be greater than 0")
1063+
raise ValueError("Epoch start times must be unique and increasing")
1064+
else:
1065+
if population_size <= 0:
1066+
raise ValueError("Scalar 'population_size' must be greater than 0")
1067+
population_size = np.array([[0, population_size]], dtype=float)
10741068
if approximate_priors:
10751069
if not approx_prior_size:
10761070
approx_prior_size = 1000

tsdate/util.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,40 @@ def add_sampledata_times(samples, sites_time):
300300
copy.sites_time[:] = sites_time
301301
copy.finalise()
302302
return copy
303+
304+
305+
def change_time_measure(time_ago, breakpoints, time_measure):
306+
"""
307+
Rescales time given a piecewise-constant time measure (e.g. a piecewise
308+
constant demographic history). To convert from generations to coalescent
309+
units, the time measure per epoch should be 2 * effective population size. To
310+
convert from coalescent units to generations, the time measure should be
311+
the coalescent rate ``1/(2 * Ne)``.
312+
313+
:param np.ndarray time_ago: An increasing vector of time points
314+
:param np.ndarray breakpoints: Start times of pieces
315+
:param np.ndarray time_measure: Time measure within pieces
316+
317+
:return: Inputs in new time measure
318+
"""
319+
320+
assert sorted(breakpoints)
321+
assert np.min(breakpoints) == 0.0
322+
assert np.all(time_ago >= 0.0)
323+
assert np.all(time_measure > 0.0)
324+
325+
index = np.searchsorted(breakpoints, time_ago, side="right") - 1
326+
step = np.concatenate(
327+
[
328+
[0.0],
329+
np.cumsum(
330+
breakpoints[1:] * (1.0 / time_measure[:-1] - 1.0 / time_measure[1:])
331+
),
332+
]
333+
)
334+
335+
new_time_ago = time_ago * 1.0 / time_measure[index] + step[index]
336+
new_breakpoints = breakpoints * 1.0 / time_measure + step
337+
new_time_measure = 1.0 / time_measure
338+
339+
return new_time_ago, new_breakpoints, new_time_measure

0 commit comments

Comments
 (0)