Skip to content

Commit 252529c

Browse files
committed
Allow nonfixed sample nodes
1 parent e940800 commit 252529c

File tree

1 file changed

+56
-12
lines changed

1 file changed

+56
-12
lines changed

tsdate/prior.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -952,29 +952,55 @@ def gamma_cdf(t_set, alpha, beta):
952952
return np.insert(t_set, 0, 0)
953953

954954

955-
def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=False):
955+
def fill_priors(
956+
node_parameters,
957+
timepoints,
958+
ts,
959+
Ne,
960+
*,
961+
prior_distr,
962+
node_var_override=None,
963+
progress=False,
964+
):
956965
"""
957966
Take the alpha and beta values from the node_parameters array, which contains
958-
one row for each node in the TS (including fixed nodes)
959-
and fill out a NodeGridValues object with the prior values from the
960-
gamma or lognormal distribution with those parameters.
967+
one row for each node in the TS (including fixed nodes, although alpha and beta
968+
are ignored for these nodes) and fill out a NodeGridValues object with the prior
969+
values from the gamma or lognormal distribution with those parameters.
970+
971+
For a description of `node_var_override`, see the parameter description in
972+
the `build_grid` function.
961973
962974
TODO - what if there is an internal fixed node? Should we truncate?
963975
"""
964976
if prior_distr == "lognorm":
965977
cdf_func = scipy.stats.lognorm.cdf
966-
main_param = np.sqrt(node_parameters[:, PriorParams.field_index("beta")])
978+
shape_param = np.sqrt(node_parameters[:, PriorParams.field_index("beta")])
967979
scale_param = np.exp(node_parameters[:, PriorParams.field_index("alpha")])
980+
981+
def shape_scale_from_mean_var(mean, var):
982+
a, b = lognorm_approx(mean, var)
983+
return np.sqrt(b), np.exp(a)
984+
968985
elif prior_distr == "gamma":
969986
cdf_func = scipy.stats.gamma.cdf
970-
main_param = node_parameters[:, PriorParams.field_index("alpha")]
971-
scale_param = 1 / node_parameters[:, PriorParams.field_index("beta")]
987+
shape_param = node_parameters[:, PriorParams.field_index("alpha")]
988+
scale_param = 1.0 / node_parameters[:, PriorParams.field_index("beta")]
989+
990+
def shape_scale_from_mean_var(mean, var):
991+
a, b = gamma_approx(mean, var)
992+
return a, 1.0 / b
993+
972994
else:
973995
raise ValueError("prior distribution must be lognorm or gamma")
974-
996+
if node_var_override is None:
997+
node_var_override = {}
975998
datable_nodes = np.ones(ts.num_nodes, dtype=bool)
976999
datable_nodes[ts.samples()] = False
1000+
# Mark all nodes in node_var_override as datable
1001+
datable_nodes[list(node_var_override.keys())] = True
9771002
datable_nodes = np.where(datable_nodes)[0]
1003+
9781004
prior_times = base.NodeGridValues(
9791005
ts.num_nodes,
9801006
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32),
@@ -985,8 +1011,16 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
9851011
for node in tqdm(
9861012
datable_nodes, desc="Assign Prior to Each Node", disable=not progress
9871013
):
1014+
if node in node_var_override:
1015+
shape, scale = shape_scale_from_mean_var(
1016+
mean=ts.node(node).time,
1017+
var=node_var_override[node],
1018+
)
1019+
else:
1020+
shape = shape_param[node]
1021+
scale = scale_param[node]
9881022
with np.errstate(divide="ignore", invalid="ignore"):
989-
prior_node = cdf_func(timepoints, main_param[node], scale=scale_param[node])
1023+
prior_node = cdf_func(timepoints, shape, scale=scale)
9901024
# force age to be less than max value
9911025
prior_node = np.divide(prior_node, np.max(prior_node))
9921026
# prior in each epoch
@@ -999,7 +1033,7 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
9991033
def _truncate_priors(ts, priors, progress=False):
10001034
"""
10011035
Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
1002-
if truncate_nodes in None) so they conform to the age of fixed nodes in the tree
1036+
if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
10031037
sequence
10041038
"""
10051039
tables = ts.tables
@@ -1065,6 +1099,7 @@ def build_grid(
10651099
prior_distribution="lognorm",
10661100
allow_historical_samples=None,
10671101
truncate_priors=None,
1102+
node_var_override=None,
10681103
eps=1e-6,
10691104
# Parameters below undocumented
10701105
progress=False,
@@ -1100,6 +1135,13 @@ def build_grid(
11001135
priors of their direct ancestor nodes so that the probability of being younger
11011136
than the oldest descendant sample is zero. If the tree sequence is trustworthy
11021137
this should give better results. Default: `True`
1138+
:param dict node_var_override: A dict mapping node IDs to a variance value.
1139+
Any nodes listed here will be treated as non-fixed nodes whose prior is not
1140+
calculated from the conditional coalescent but instead are allocated a prior
1141+
whose mean is the node time in the tree sequence and whose variance is the
1142+
value in this dictionary. This allows sample nodes to be treated as nonfixed
1143+
nodes, and therefore dated. If ``None`` (default) then all sample nodes are
1144+
treated as occurring at a fixed time (as if this were an empty dict).
11031145
:param float eps: Specify minimum distance separating points in the time grid. Also
11041146
specifies the error factor in time difference calculations. Default: 1e-6
11051147
:return: A prior object to pass to tsdate.date() containing prior values for
@@ -1160,16 +1202,18 @@ def build_grid(
11601202
tree_sequence,
11611203
Ne,
11621204
prior_distr=prior_distribution,
1205+
node_var_override=node_var_override,
11631206
progress=progress,
11641207
)
1165-
if np.any(tree_sequence.tables.nodes.time[tree_sequence.samples()] != 0):
1208+
tables = tree_sequence.tables
1209+
if np.any(tables.nodes.time[tree_sequence.samples()] > 0):
11661210
if not allow_historical_samples:
11671211
raise ValueError(
11681212
"There are samples at non-zero times, invalidating the conditional "
11691213
"coalescent prior. You can set allow_historical_samples=True to carry "
11701214
"on regardless, calculating a prior as if all samples were "
11711215
"contemporaneous (reasonable if you only have a few ancient samples)"
11721216
)
1173-
if truncate_priors:
1217+
if np.any(tables.nodes.time[priors.fixed_node_ids()] > 0) and truncate_priors:
11741218
priors = _truncate_priors(tree_sequence, priors, progress=progress)
11751219
return priors

0 commit comments

Comments
 (0)