Skip to content

Commit a40d1a8

Browse files
committed
Refactor into PopulationSizeHistory class
1 parent 9327e00 commit a40d1a8

File tree

7 files changed

+228
-114
lines changed

7 files changed

+228
-114
lines changed

tests/test_accuracy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ def test_piecewise_scaling(self, bkwd_rate, trio_tmrca):
170170
time = np.linspace(0, 10, 100)
171171
ne = 0.5 * np.exp(bkwd_rate * time)
172172
ts = tskit.Tree.generate_comb(3).tree_sequence
173-
dts = tsdate.date(
174-
ts, population_size=np.column_stack([time, ne]), mutation_rate=None
175-
)
173+
demo = tsdate.demography.PopulationSizeHistory(ne, time[1:])
174+
dts = tsdate.date(ts, population_size=demo, mutation_rate=None)
176175
# Check the date is within 10% of the expected
177176
assert 0.9 < dts.node(dts.first().root).time / trio_tmrca < 1.1

tests/test_functions.py

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@
4545
from tsdate.core import Likelihoods
4646
from tsdate.core import LogLikelihoods
4747
from tsdate.core import posterior_mean_var
48+
from tsdate.demography import PopulationSizeHistory
4849
from tsdate.prior import ConditionalCoalescentTimes
4950
from tsdate.prior import fill_priors
5051
from tsdate.prior import gamma_approx
5152
from tsdate.prior import PriorParams
5253
from tsdate.prior import SpansBySamples
53-
from tsdate.util import change_time_measure
5454
from tsdate.util import nodes_time_unconstrained
5555

5656

@@ -512,7 +512,7 @@ def test_two_tree_mutation_ts_intervals(self):
512512
class TestPriorVals:
513513
def verify_prior_vals(self, ts, prior_distr, **kwargs):
514514
span_data = SpansBySamples(ts, **kwargs)
515-
Ne = np.array([[0, 0.5]])
515+
Ne = PopulationSizeHistory(0.5)
516516
priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr)
517517
priors.add(ts.num_samples, approximate=False)
518518
grid = np.linspace(0, 3, 3)
@@ -1102,7 +1102,7 @@ def run_outside_algorithm(
11021102
self, ts, prior_distr="lognorm", standardize=False, ignore_oldest_root=False
11031103
):
11041104
span_data = SpansBySamples(ts)
1105-
Ne = np.array([[0, 0.5]])
1105+
Ne = PopulationSizeHistory(0.5)
11061106
priors = ConditionalCoalescentTimes(None, prior_distr)
11071107
priors.add(ts.num_samples, approximate=False)
11081108
grid = np.array([0, 1.2, 2])
@@ -1205,7 +1205,7 @@ class TestTotalFunctionalValueTree:
12051205
def find_posterior(self, ts, prior_distr):
12061206
grid = np.array([0, 1.2, 2])
12071207
span_data = SpansBySamples(ts)
1208-
Ne = np.array([[0, 0.5]])
1208+
Ne = PopulationSizeHistory(0.5)
12091209
priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr)
12101210
priors.add(ts.num_samples, approximate=False)
12111211
mixture_priors = priors.get_mixture_prior_params(span_data)
@@ -1269,7 +1269,7 @@ def test_gil_tree(self):
12691269
ts = utility_functions.gils_example_tree()
12701270
span_data = SpansBySamples(ts)
12711271
prior_distr = "lognorm"
1272-
Ne = np.array([[0, 0.5]])
1272+
Ne = PopulationSizeHistory(0.5)
12731273
priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr)
12741274
priors.add(ts.num_samples, approximate=False)
12751275
grid = np.array([0, 0.1, 0.2, 0.5, 1, 2, 5])
@@ -1978,45 +1978,95 @@ def test_historical_samples(self):
19781978
)
19791979

19801980

1981-
class TestRescaleTime:
1982-
def test_rescale(self):
1981+
class TestPopulationSizeHistory:
1982+
def test_change_time_measure_scalar(self):
19831983
Ne = 10000
19841984
coaltime = np.array([0, 1, 2, 3])
19851985
coalstart = np.array([0])
19861986
coalrate = np.array([1 / (2 * Ne)])
1987-
gens, _, _ = change_time_measure(coaltime, coalstart, coalrate)
1987+
gens, _, _ = PopulationSizeHistory._change_time_measure(
1988+
coaltime, coalstart, coalrate
1989+
)
19881990
assert np.allclose(gens, 2 * coaltime * Ne)
19891991

1990-
def test_piecewise(self):
1992+
def test_change_time_measure_piecewise(self):
19911993
Ne = np.array([2000, 3000, 5000])
19921994
start = np.array([0, 4000, 10000])
19931995
gens = np.array([2000, 7000, 15000])
1994-
coaltime, coalstart, coalrate = change_time_measure(gens, start, 2 * Ne)
1996+
coaltime, coalstart, coalrate = PopulationSizeHistory._change_time_measure(
1997+
gens, start, 2 * Ne
1998+
)
19951999
assert np.allclose(coalstart, np.array([0, 1, 2]))
19962000
assert np.allclose(coaltime, np.array([0.5, 1.5, 2.5]))
19972001
assert np.allclose(coalrate, 1 / (2 * Ne))
19982002

1999-
def test_piecewise_bijection(self):
2003+
def test_change_time_measure_bijection(self):
20002004
hapNe = np.array([2000, 3000, 5000])
20012005
start = np.array([0, 4000, 10000])
20022006
gens = np.array([500, 7000, 15000])
2003-
coaltime, coalstart, coalrate = change_time_measure(gens, start, hapNe)
2004-
gens_back, start_back, hapNe_back = change_time_measure(
2007+
coaltime, coalstart, coalrate = PopulationSizeHistory._change_time_measure(
2008+
gens, start, hapNe
2009+
)
2010+
gens_back, start_back, hapNe_back = PopulationSizeHistory._change_time_measure(
20052011
coaltime, coalstart, coalrate
20062012
)
20072013
assert np.allclose(gens, gens_back)
20082014
assert np.allclose(start, start_back)
20092015
assert np.allclose(hapNe, hapNe_back)
20102016

2011-
def test_piecewise_numerically(self):
2017+
def test_change_time_measure_numerically(self):
20122018
coalrate = np.array([0.001, 0.01, 0.1])
20132019
coalstart = np.array([0, 1, 2])
20142020
coaltime = np.linspace(0, 3, 10)
2015-
gens, _, _ = change_time_measure(coaltime, coalstart, coalrate)
2021+
gens, _, _ = PopulationSizeHistory._change_time_measure(
2022+
coaltime, coalstart, coalrate
2023+
)
20162024
for i in range(gens.size):
20172025
x, _ = scipy.integrate.quad(
20182026
lambda t: 1 / coalrate[np.digitize(t, coalstart) - 1],
20192027
a=0,
20202028
b=coaltime[i],
20212029
)
20222030
assert np.isclose(x, gens[i])
2031+
2032+
def test_to_coalescent_timescale(self):
2033+
demography = PopulationSizeHistory(
2034+
np.array([1000, 2000, 3000]), np.array([500, 2500])
2035+
)
2036+
coaltime = demography.to_coalescent_timescale(np.array([250, 1500]))
2037+
assert np.allclose(coaltime, [0.125, 0.5])
2038+
2039+
def test_to_natural_timescale(self):
2040+
demography = PopulationSizeHistory(
2041+
np.array([1000, 2000, 3000]), np.array([500, 2500])
2042+
)
2043+
time = demography.to_natural_timescale(np.array([0.125, 0.5]))
2044+
assert np.allclose(time, [250, 1500])
2045+
2046+
def test_single_epoch(self):
2047+
for Ne in [10000, np.array([10000])]:
2048+
demography = PopulationSizeHistory(Ne)
2049+
time = demography.to_natural_timescale(np.array([0, 1, 2, 3]))
2050+
assert np.allclose(time, [0.0, 20000, 40000, 60000])
2051+
2052+
def test_bad_arguments(self):
2053+
with pytest.raises(ValueError, match="a numpy array"):
2054+
PopulationSizeHistory([1])
2055+
with pytest.raises(ValueError, match="a numpy array"):
2056+
PopulationSizeHistory(np.array([1, 1]), [1])
2057+
with pytest.raises(ValueError, match="must be greater than 0"):
2058+
PopulationSizeHistory(0)
2059+
with pytest.raises(ValueError, match="must be greater than 0"):
2060+
PopulationSizeHistory(np.array([0, 0]), np.array([1]))
2061+
with pytest.raises(ValueError, match="must be greater than 0"):
2062+
PopulationSizeHistory(np.array([1, 1]), np.array([0]))
2063+
with pytest.raises(ValueError, match="one less than the number"):
2064+
PopulationSizeHistory(np.array([1]), np.array([1]))
2065+
with pytest.raises(ValueError, match="increasing order"):
2066+
PopulationSizeHistory(np.array([1, 1, 1]), np.array([2, 1]))
2067+
demography = PopulationSizeHistory(1)
2068+
for time in [1, [1]]:
2069+
with pytest.raises(ValueError, match="a numpy array"):
2070+
demography.to_natural_timescale(time)
2071+
with pytest.raises(ValueError, match="a numpy array"):
2072+
demography.to_coalescent_timescale(time)

tests/test_inference.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,9 @@ def test_not_needed_population_size(self):
5555

5656
def test_bad_population_size(self):
5757
ts = utility_functions.two_tree_mutation_ts()
58-
with pytest.raises(ValueError, match="greater than 0"):
59-
tsdate.date(ts, mutation_rate=None, population_size=0)
60-
with pytest.raises(ValueError, match="greater than 0"):
61-
tsdate.date(ts, mutation_rate=None, population_size=-1)
62-
with pytest.raises(ValueError, match="two-dimensional"):
63-
tsdate.date(ts, mutation_rate=None, population_size=np.array([[[1]]]))
64-
with pytest.raises(ValueError, match="two columns"):
65-
tsdate.date(ts, mutation_rate=None, population_size=np.array([[1]]))
66-
with pytest.raises(ValueError, match="nonnegative"):
67-
tsdate.date(ts, mutation_rate=None, population_size=np.array([[-1, 1]]))
68-
with pytest.raises(ValueError, match="positive"):
69-
tsdate.date(ts, mutation_rate=None, population_size=np.array([[0, 0]]))
70-
with pytest.raises(ValueError, match="start at time 0"):
71-
tsdate.date(ts, mutation_rate=None, population_size=np.array([[1, 1]]))
72-
with pytest.raises(ValueError, match="unique and increasing"):
73-
tsdate.date(
74-
ts, mutation_rate=None, population_size=np.array([[0, 1], [0, 1]])
75-
)
58+
for Ne in [0, -1]:
59+
with pytest.raises(ValueError, match="greater than 0"):
60+
tsdate.date(ts, mutation_rate=None, population_size=Ne)
7661

7762
def test_dangling_failure(self):
7863
ts = utility_functions.single_tree_ts_n2_dangling()

tsdate/core.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,16 +1032,11 @@ def date(
10321032
# Remove any times associated with mutations
10331033
tables.mutations.time = np.full(tree_sequence.num_mutations, tskit.UNKNOWN_TIME)
10341034
tables.sort()
1035-
population_size_provenance = (
1036-
population_size.tolist()
1037-
if isinstance(population_size, np.ndarray)
1038-
else population_size
1039-
)
1035+
# TODO: record population_size provenance, or record that it is omitted
10401036
provenance.record_provenance(
10411037
tables,
10421038
"date",
10431039
mutation_rate=mutation_rate,
1044-
population_size=population_size_provenance,
10451040
recombination_rate=recombination_rate,
10461041
progress=progress,
10471042
**kwargs,

tsdate/demography.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# MIT License
2+
#
3+
# Copyright (c) 2020 University of Oxford
4+
# Copyright (c) 2021-2023 Tskit Developers
5+
#
6+
# Permission is hereby granted, free of charge, to any person obtaining a copy
7+
# of this software and associated documentation files (the "Software"), to deal
8+
# in the Software without restriction, including without limitation the rights
9+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
# copies of the Software, and to permit persons to whom the Software is
11+
# furnished to do so, subject to the following conditions:
12+
#
13+
# The above copyright notice and this permission notice shall be included in
14+
# all copies or substantial portions of the Software.
15+
#
16+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
# SOFTWARE.
23+
"""
24+
Routines and classes for manipulating demographic histories in tsdate
25+
"""
26+
import numpy as np
27+
28+
29+
class PopulationSizeHistory:
    """
    Stores a piecewise constant population size history and transforms time
    between a natural (generational) scale and a coalescent one.

    Instances expose four arrays, one entry per epoch:

    - ``time_breaks``: epoch start times in generations (always begins at 0)
    - ``population_size``: twice the population size passed to the constructor
    - ``coalescent_breaks``: epoch start times in coalescent units
    - ``coalescent_rate``: ``1 / population_size`` for each epoch
    """

    @staticmethod
    def _change_time_measure(time_ago, breakpoints, time_measure):
        """
        Rescales time given a piecewise-constant time measure. To convert from
        generations to coalescent units, the time measure per epoch should be 2 *
        effective population size. To convert from coalescent units to
        generations, the time measure should be the coalescent rate ``1/(2 * Ne)``.

        :param np.ndarray time_ago: An increasing vector of time points
        :param np.ndarray breakpoints: Start times of pieces
        :param np.ndarray time_measure: Time measure within pieces

        :return: Inputs in new time measure, as a tuple ``(new_time_ago,
            new_breakpoints, new_time_measure)``
        """

        # Internal invariants; public entry points validate user input first.
        assert np.all(np.diff(breakpoints) > 0.0)
        assert np.min(breakpoints) == 0.0
        assert np.all(time_ago >= 0.0)
        assert np.all(time_measure > 0.0)
        assert breakpoints.size == time_measure.size
        # Index of the piece that contains each time point (a time that falls
        # exactly on a breakpoint belongs to the piece starting there).
        index = np.searchsorted(breakpoints, time_ago, side="right") - 1
        # Offset for each piece that keeps the rescaled time continuous across
        # breakpoints: the accumulated difference between the old and new slope
        # evaluated at each boundary.
        step = np.concatenate(
            [
                [0.0],
                np.cumsum(
                    breakpoints[1:] * (1.0 / time_measure[:-1] - 1.0 / time_measure[1:])
                ),
            ]
        )
        new_time_ago = time_ago * 1.0 / time_measure[index] + step[index]
        new_breakpoints = breakpoints * 1.0 / time_measure + step
        new_time_measure = 1.0 / time_measure
        return new_time_ago, new_breakpoints, new_time_measure

    def __init__(self, population_size, time_breaks=None):
        """
        :param np.ndarray population_size: A numpy array containing diploid
            population sizes per epoch. A single number (Python or numpy
            scalar) is accepted as shorthand for a one-epoch history.
        :param np.ndarray time_breaks: A sorted numpy array containing time
            breaks that divide epochs, measured in units of generations in the
            past. Must contain exactly one fewer element than
            ``population_size``; omit it for a single epoch.
        :raises ValueError: If an argument is not a numpy array (or scalar, for
            ``population_size``), if any population size or time break is not
            greater than 0, if the array lengths are inconsistent, or if the
            time breaks are not unique and increasing.
        """

        if time_breaks is None:
            time_breaks = np.array([], dtype=float)

        # numpy scalar types are included so that values pulled out of an
        # array (e.g. ``sizes[0]``) are treated like plain Python numbers.
        if isinstance(population_size, (int, float, np.integer, np.floating)):
            # ``not x > 0`` (rather than ``x <= 0``) also rejects NaN
            if not population_size > 0:
                raise ValueError("Population size must be greater than 0")
            population_size = np.array([population_size], dtype=float)
        else:
            if not isinstance(population_size, np.ndarray):
                raise ValueError("Population sizes must be in a numpy array")
            if not np.all(population_size > 0.0):
                raise ValueError("Population sizes must be greater than 0")
        if not isinstance(time_breaks, np.ndarray):
            raise ValueError("Epoch time breaks must be in a numpy array")
        if not time_breaks.size == population_size.size - 1:
            raise ValueError(
                "The number of epoch time breaks must be one less "
                "than the number of population sizes"
            )
        if time_breaks.size > 0:
            if not np.all(time_breaks > 0.0):
                raise ValueError("Epoch time breaks must be greater than 0")
            if not np.all(np.diff(time_breaks) > 0.0):
                raise ValueError(
                    "Epoch time breaks must be unique and in increasing order"
                )

        # The first epoch always starts at the present (time 0)
        self.time_breaks = np.append([0.0], time_breaks.flatten())
        # Stored as twice the supplied size: this is the time measure used to
        # convert generations into coalescent units.
        self.population_size = 2 * population_size.flatten()
        # Precompute the epoch boundaries and rates on the coalescent scale so
        # that the inverse transformation needs no further setup.
        _, coalescent_breaks, coalescent_rate = self._change_time_measure(
            self.time_breaks, self.time_breaks, self.population_size
        )
        self.coalescent_breaks = coalescent_breaks
        self.coalescent_rate = coalescent_rate

    def to_natural_timescale(self, coalescent_time_ago):
        """
        Convert a vector of times from coalescent units to generations

        :param np.ndarray coalescent_time_ago: Times in the past, in coalescent units
        :return: Times in the past, in generations
        :raises ValueError: If ``coalescent_time_ago`` is not a numpy array
        """

        if not isinstance(coalescent_time_ago, np.ndarray):
            raise ValueError("Times must be in a numpy array")
        time_ago, _, _ = self._change_time_measure(
            coalescent_time_ago,
            self.coalescent_breaks,
            self.coalescent_rate,
        )
        return time_ago

    def to_coalescent_timescale(self, time_ago):
        """
        Convert a vector of times from generations to coalescent units

        :param np.ndarray time_ago: Times in the past, in generations
        :return: Times in the past, in coalescent units
        :raises ValueError: If ``time_ago`` is not a numpy array
        """

        if not isinstance(time_ago, np.ndarray):
            raise ValueError("Times must be in a numpy array")
        coalescent_time_ago, _, _ = self._change_time_measure(
            time_ago,
            self.time_breaks,
            self.population_size,
        )
        return coalescent_time_ago

    # TODO:
    # @staticmethod
    # def from_demes(filename):
    #    """
    #    Create a `PopulationSizeHistory` instance from a `demes` format YAML
    #    """

0 commit comments

Comments
 (0)