Skip to content

Commit 786cf12

Browse files
committed
Allow nonfixed sample nodes
1 parent 7999a58 commit 786cf12

File tree

1 file changed

+56
-12
lines changed

1 file changed

+56
-12
lines changed

tsdate/prior.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -947,29 +947,55 @@ def gamma_cdf(t_set, alpha, beta):
947947
return np.insert(t_set, 0, 0)
948948

949949

950-
def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=False):
950+
def fill_priors(
951+
node_parameters,
952+
timepoints,
953+
ts,
954+
Ne,
955+
*,
956+
prior_distr,
957+
node_var_override=None,
958+
progress=False,
959+
):
951960
"""
952961
Take the alpha and beta values from the node_parameters array, which contains
953-
one row for each node in the TS (including fixed nodes)
954-
and fill out a NodeGridValues object with the prior values from the
955-
gamma or lognormal distribution with those parameters.
962+
one row for each node in the TS (including fixed nodes, although alpha and beta
963+
are ignored for these nodes) and fill out a NodeGridValues object with the prior
964+
values from the gamma or lognormal distribution with those parameters.
965+
966+
For a description of `node_var_override`, see the parameter description in
967+
the `build_grid` function.
956968
957969
TODO - what if there is an internal fixed node? Should we truncate?
958970
"""
959971
if prior_distr == "lognorm":
960972
cdf_func = scipy.stats.lognorm.cdf
961-
main_param = np.sqrt(node_parameters[:, PriorParams.field_index("beta")])
973+
shape_param = np.sqrt(node_parameters[:, PriorParams.field_index("beta")])
962974
scale_param = np.exp(node_parameters[:, PriorParams.field_index("alpha")])
975+
976+
def shape_scale_from_mean_var(mean, var):
977+
a, b = lognorm_approx(mean, var)
978+
return np.sqrt(b), np.exp(a)
979+
963980
elif prior_distr == "gamma":
964981
cdf_func = scipy.stats.gamma.cdf
965-
main_param = node_parameters[:, PriorParams.field_index("alpha")]
966-
scale_param = 1 / node_parameters[:, PriorParams.field_index("beta")]
982+
shape_param = node_parameters[:, PriorParams.field_index("alpha")]
983+
scale_param = 1.0 / node_parameters[:, PriorParams.field_index("beta")]
984+
985+
def shape_scale_from_mean_var(mean, var):
986+
a, b = gamma_approx(mean, var)
987+
return a, 1.0 / b
988+
967989
else:
968990
raise ValueError("prior distribution must be lognorm or gamma")
969-
991+
if node_var_override is None:
992+
node_var_override = {}
970993
datable_nodes = np.ones(ts.num_nodes, dtype=bool)
971994
datable_nodes[ts.samples()] = False
995+
# Mark all nodes in node_var_override as datable
996+
datable_nodes[list(node_var_override.keys())] = True
972997
datable_nodes = np.where(datable_nodes)[0]
998+
973999
prior_times = base.NodeGridValues(
9741000
ts.num_nodes,
9751001
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32),
@@ -980,8 +1006,16 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
9801006
for node in tqdm(
9811007
datable_nodes, desc="Assign Prior to Each Node", disable=not progress
9821008
):
1009+
if node in node_var_override:
1010+
shape, scale = shape_scale_from_mean_var(
1011+
mean=ts.node(node).time,
1012+
var=node_var_override[node],
1013+
)
1014+
else:
1015+
shape = shape_param[node]
1016+
scale = scale_param[node]
9831017
with np.errstate(divide="ignore", invalid="ignore"):
984-
prior_node = cdf_func(timepoints, main_param[node], scale=scale_param[node])
1018+
prior_node = cdf_func(timepoints, shape, scale=scale)
9851019
# force age to be less than max value
9861020
prior_node = np.divide(prior_node, np.max(prior_node))
9871021
# prior in each epoch
@@ -994,7 +1028,7 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
9941028
def _truncate_priors(ts, priors, progress=False):
9951029
"""
9961030
Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
997-
if truncate_nodes in None) so they conform to the age of fixed nodes in the tree
1031+
if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
9981032
sequence
9991033
"""
10001034
tables = ts.tables
@@ -1060,6 +1094,7 @@ def build_grid(
10601094
prior_distribution="lognorm",
10611095
allow_historical_samples=None,
10621096
truncate_priors=None,
1097+
node_var_override=None,
10631098
eps=1e-6,
10641099
progress=False,
10651100
):
@@ -1094,6 +1129,13 @@ def build_grid(
10941129
priors of their direct ancestor nodes so that the probability of being younger
10951130
than the oldest descendant sample is zero. If the tree sequence is trustworthy
10961131
this should give better results. Default: `True`
1132+
:param dict node_var_override: A dict mapping node IDs to a variance value.
1133+
Any nodes listed here will be treated as non-fixed nodes whose prior is not
1134+
calculated from the conditional coalescent but instead are allocated a prior
1135+
whose mean is the node time in the tree sequence and whose variance is the
1136+
value in this dictionary. This allows sample nodes to be treated as nonfixed
1137+
nodes, and therefore dated. If ``None`` (default) then all sample nodes are
1138+
treated as occurring at a fixed time (as if this were an empty dict).
10971139
:param float eps: Specify minimum distance separating points in the time grid. Also
10981140
specifies the error factor in time difference calculations. Default: 1e-6
10991141
:return: A prior object to pass to tsdate.date() containing prior values for
@@ -1154,16 +1196,18 @@ def build_grid(
11541196
tree_sequence,
11551197
Ne,
11561198
prior_distr=prior_distribution,
1199+
node_var_override=node_var_override,
11571200
progress=progress,
11581201
)
1159-
if np.any(tree_sequence.tables.nodes.time[tree_sequence.samples()] != 0):
1202+
tables = tree_sequence.tables
1203+
if np.any(tables.nodes.time[tree_sequence.samples()] > 0):
11601204
if not allow_historical_samples:
11611205
raise ValueError(
11621206
"There are samples at non-zero times, invalidating the conditional "
11631207
"coalescent prior. You can set allow_historical_samples=True to carry "
11641208
"on regardless, calculating a prior as if all samples were "
11651209
"contemporaneous (reasonable if you only have a few ancient samples)"
11661210
)
1167-
if truncate_priors:
1211+
if np.any(tables.nodes.time[priors.fixed_node_ids()] > 0) and truncate_priors:
11681212
priors = _truncate_priors(tree_sequence, priors, progress=progress)
11691213
return priors

0 commit comments

Comments
 (0)