Skip to content

Commit 4a1890c

Browse files
authored
Merge pull request #283 from hyanwong/popsizeprov
Save PopulationSizeHistory params
2 parents f9a2c57 + 868e177 commit 4a1890c

File tree

4 files changed

+87
-57
lines changed

4 files changed

+87
-57
lines changed

tests/test_functions.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2147,16 +2147,12 @@ def test_change_time_measure_numerically(self):
21472147
assert np.isclose(x, gens[i])
21482148

21492149
def test_to_coalescent_timescale(self):
2150-
demography = PopulationSizeHistory(
2151-
np.array([1000, 2000, 3000]), np.array([500, 2500])
2152-
)
2150+
demography = PopulationSizeHistory([1000, 2000, 3000], [500, 2500])
21532151
coaltime = demography.to_coalescent_timescale(np.array([250, 1500]))
21542152
assert np.allclose(coaltime, [0.125, 0.5])
21552153

21562154
def test_to_natural_timescale(self):
2157-
demography = PopulationSizeHistory(
2158-
np.array([1000, 2000, 3000]), np.array([500, 2500])
2159-
)
2155+
demography = PopulationSizeHistory([1000, 2000, 3000], [500, 2500])
21602156
time = demography.to_natural_timescale(np.array([0.125, 0.5]))
21612157
assert np.allclose(time, [250, 1500])
21622158

@@ -2169,9 +2165,7 @@ def test_single_epoch(self):
21692165
def test_moments_numerically(self):
21702166
alpha = 2.8
21712167
beta = 1.7
2172-
demography = PopulationSizeHistory(
2173-
np.array([1000, 2000, 3000]), np.array([500, 2500])
2174-
)
2168+
demography = PopulationSizeHistory([1000, 2000, 3000], [500, 2500])
21752169
numer_mn, _ = scipy.integrate.quad(
21762170
lambda t: demography.to_natural_timescale(np.array([t]))
21772171
* scipy.stats.gamma.pdf(t, alpha, scale=1 / beta),
@@ -2192,10 +2186,16 @@ def test_moments_numerically(self):
21922186
assert np.isclose(numer_va, analy_va)
21932187

21942188
def test_bad_arguments(self):
2195-
with pytest.raises(ValueError, match="a numpy array"):
2196-
PopulationSizeHistory([1])
2197-
with pytest.raises(ValueError, match="a numpy array"):
2198-
PopulationSizeHistory(np.array([1, 1]), [1])
2189+
with pytest.raises(ValueError, match="greater than 0"):
2190+
PopulationSizeHistory([None])
2191+
with pytest.raises(ValueError, match="finite"):
2192+
PopulationSizeHistory(np.inf)
2193+
with pytest.raises(ValueError, match="a numpy float array"):
2194+
PopulationSizeHistory(["a"])
2195+
with pytest.raises(TypeError, match="a numpy float array"):
2196+
PopulationSizeHistory([{}])
2197+
with pytest.raises(TypeError, match="a numpy float array"):
2198+
PopulationSizeHistory(np.array([1, 1]), [{}])
21992199
with pytest.raises(ValueError, match="must be greater than 0"):
22002200
PopulationSizeHistory(0)
22012201
with pytest.raises(ValueError, match="must be greater than 0"):

tests/test_provenance.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,26 +54,29 @@ def test_date_params_recorded(self):
5454
assert np.isclose(rec["parameters"]["mutation_rate"], mu)
5555
assert np.isclose(rec["parameters"]["population_size"], Ne)
5656

57-
@pytest.mark.skip(
58-
reason="Not implemented yet: https://github.com/tskit-dev/tsdate/issues/274"
57+
@pytest.mark.parametrize(
58+
"popdict",
59+
[
60+
{"population_size": [1, 2, 3], "time_breaks": [1, 1.2]},
61+
{"population_size": [123]},
62+
],
5963
)
60-
def test_date_popsizehist_recorded(self):
64+
def test_date_popsizehist_recorded(self, popdict):
6165
ts = utility_functions.single_tree_ts_n2()
6266
mu = 0.123
63-
popsize = tsdate.demography.PopulationSizeHistory(
64-
np.array([1, 1]), np.array([1])
65-
)
66-
dated_ts = tsdate.date(ts, population_size=popsize, mutation_rate=mu)
67-
rec = json.loads(dated_ts.provenance(-1).record)
68-
assert np.isclose(rec["parameters"]["mutation_rate"], mu)
69-
assert "population_size" in rec["parameters"]
70-
# TODO: check that the population size history is recorded correctly
71-
# https://github.com/tskit-dev/tsdate/issues/274
72-
assert np.isclose(
73-
# This is wrong - left in for now to show the sort of thing we want
74-
(np.array(rec["parameters"]["population_size"])),
75-
popsize,
76-
)
67+
for use_class in (False, True):
68+
if use_class:
69+
popsize = tsdate.demography.PopulationSizeHistory(**popdict)
70+
else:
71+
popsize = popdict
72+
dated_ts = tsdate.date(ts, population_size=popsize, mutation_rate=mu)
73+
rec = json.loads(dated_ts.provenance(-1).record)
74+
assert np.isclose(rec["parameters"]["mutation_rate"], mu)
75+
assert "population_size" in rec["parameters"]
76+
popsz = rec["parameters"]["population_size"]
77+
assert len(popsz) == len(popdict)
78+
for param, val in popdict.items():
79+
assert np.all(np.isclose(val, popsz[param]))
7780

7881
def test_preprocess_cmd_recorded(self):
7982
ts = utility_functions.ts_w_data_desert(40, 60, 100)

tsdate/core.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
from . import approx
4141
from . import base
42+
from . import demography
4243
from . import prior
4344
from . import provenance
4445

@@ -1169,10 +1170,12 @@ def date(
11691170
one whose non-sample nodes are undated.
11701171
:param PopulationSizeHistory population_size: The estimated (diploid) effective
11711172
population size used to construct the (default) conditional coalescent
1172-
prior. This may be a single value (for a population with constant size), or
1173-
a :class:`PopulationSizeHistory` object (for a population with time-varying
1174-
size). This is used when ``priors`` is ``None``. Conversely, if ``priors``
1175-
is not ``None``, no ``population_size`` value should be given.
1173+
prior. For a population with constant size, this can be given as a single
1174+
value. For a population with time-varying size, this can be given directly as
1175+
a :class:`PopulationSizeHistory` object or a parameter dictionary passed
1176+
to initialise a :class:`PopulationSizeHistory` object. This is used when
1177+
``priors`` is ``None``. Conversely, if ``priors`` is not ``None``, no
1178+
``population_size`` value should be specified.
11761179
:param float mutation_rate: The estimated mutation rate per unit of genome per
11771180
unit time. If provided, the dating algorithm will use a mutation rate clock to
11781181
help estimate node dates. Default: ``None``
@@ -1226,6 +1229,10 @@ def date(
12261229
)
12271230
else:
12281231
population_size = Ne
1232+
1233+
if isinstance(population_size, dict):
1234+
population_size = demography.PopulationSizeHistory(**population_size)
1235+
12291236
if method == "variational_gamma":
12301237
tree_sequence, dates, posteriors, timepoints, eps, nds = variational_dates(
12311238
tree_sequence,
@@ -1261,8 +1268,8 @@ def date(
12611268
)
12621269
if isinstance(population_size, (int, float)):
12631270
params["population_size"] = population_size
1264-
else:
1265-
params["population_size"] = "TODO: PopulationSizeHistory object"
1271+
elif isinstance(population_size, demography.PopulationSizeHistory):
1272+
params["population_size"] = population_size.as_dict()
12661273
provenance.record_provenance(
12671274
tables,
12681275
"date",

tsdate/demography.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -69,39 +69,48 @@ def _change_time_measure(time_ago, breakpoints, time_measure):
6969

7070
def __init__(self, population_size, time_breaks=None):
7171
"""
72-
:param np.ndarray population_size: A numpy array containing diploid
72+
:param array_like population_size: An array containing diploid
7373
population sizes per epoch
74-
:param np.ndarray time_breaks: A sorted numpy array containing time
74+
:param array_like time_breaks: A sorted array containing time
7575
breaks that divide epochs, measured in units of generations in the
7676
past
7777
"""
7878

7979
if time_breaks is None:
80-
time_breaks = np.array([], dtype=float)
80+
time_breaks = []
8181

8282
if isinstance(population_size, (int, float)):
83-
if not population_size > 0:
84-
raise ValueError("Population size must be greater than 0")
8583
population_size = np.array([population_size], dtype=float)
8684
else:
87-
if not isinstance(population_size, np.ndarray):
88-
raise ValueError("Population sizes must be in a numpy array")
89-
if not np.all(population_size > 0.0):
90-
raise ValueError("Population sizes must be greater than 0")
91-
if not isinstance(time_breaks, np.ndarray):
92-
raise ValueError("Epoch time breaks must be in a numpy array")
93-
if not time_breaks.size == population_size.size - 1:
85+
try:
86+
population_size = np.array(population_size, dtype=float)
87+
except (ValueError, TypeError) as e:
88+
raise e.__class__(
89+
"Population sizes must be convertable to a numpy float array"
90+
) from e
91+
if not np.all(population_size > 0.0):
92+
raise ValueError("Population sizes must be greater than 0")
93+
if not np.all(np.isfinite(population_size)):
94+
raise ValueError("Population sizes must be finite")
95+
96+
try:
97+
time_breaks = np.array(time_breaks, dtype=float)
98+
except (ValueError, TypeError) as e:
99+
raise e.__class__(
100+
"Time breaks must be convertable to a numpy float array"
101+
) from e
102+
if not time_breaks.size == population_size.size - 1:
103+
raise ValueError(
104+
"The length of the population size array must be one less "
105+
"than the number of epoch time breaks"
106+
)
107+
if time_breaks.size > 0:
108+
if not np.all(time_breaks > 0.0):
109+
raise ValueError("Epoch time breaks must be greater than 0")
110+
if not np.all(np.diff(time_breaks) > 0.0):
94111
raise ValueError(
95-
"The length of the population size array must be one less "
96-
"than the number of epoch time breaks"
112+
"Epoch time breaks must be unique and in increasing order"
97113
)
98-
if time_breaks.size > 0:
99-
if not np.all(time_breaks > 0.0):
100-
raise ValueError("Epoch time breaks must be greater than 0")
101-
if not np.all(np.diff(time_breaks) > 0.0):
102-
raise ValueError(
103-
"Epoch time breaks must be unique and in increasing order"
104-
)
105114

106115
self.time_breaks = np.append([0.0], time_breaks.flatten())
107116
self.population_size = 2 * population_size.flatten()
@@ -111,6 +120,17 @@ def __init__(self, population_size, time_breaks=None):
111120
self.coalescent_breaks = coalescent_breaks
112121
self.coalescent_rate = coalescent_rate
113122

123+
def as_dict(self):
124+
"""
125+
Return the population size history as a dictionary of parameters
126+
that can be used to initialise a new object
127+
"""
128+
ret_val = {"population_size": list(self.population_size / 2)}
129+
assert self.time_breaks[0] == 0.0
130+
if len(self.time_breaks) > 1:
131+
ret_val["time_breaks"] = list(self.time_breaks[1:])
132+
return ret_val
133+
114134
def to_natural_timescale(self, coalescent_time_ago):
115135
"""
116136
Convert a vector of times from coalescent units to generations

0 commit comments

Comments
 (0)