Skip to content

Commit 09b5624

Browse files
committed
Change name and correct identify constant
1 parent 5c8aed4 commit 09b5624

File tree

2 files changed

+27
-22
lines changed

2 files changed

+27
-22
lines changed

tsdate/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -705,14 +705,14 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
705705
raise ValueError
706706
if cache_inside:
707707
g_i[edge.id] = edge_lik
708-
norm[parent] = np.max(val) if normalize else 1
708+
norm[parent] = np.max(val) if normalize else self.lik.identity_constant
709709
inside[parent] = self.lik.reduce(val, norm[parent])
710710
to_visit[parent] = False
711711

712712
# There may be nodes that are not parents but are also not fixed (e.g.
713713
# undated sample nodes). These need an identity normalization constant
714714
for unfixed_unvisited in np.where(to_visit)[0]:
715-
norm[unfixed_unvisited] = 1
715+
norm[unfixed_unvisited] = self.lik.identity_constant
716716

717717
if cache_inside:
718718
self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None])

tsdate/prior.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ def fill_priors(
960960
Ne,
961961
*,
962962
prior_distr,
963-
node_var_override=None,
963+
nonfixed_sample_var=None,
964964
progress=False,
965965
):
966966
"""
@@ -969,7 +969,7 @@ def fill_priors(
969969
are ignored for these nodes) and fill out a NodeGridValues object with the prior
970970
values from the gamma or lognormal distribution with those parameters.
971971
972-
For a description of `node_var_override`, see the parameter description in
972+
For a description of `nonfixed_sample_var`, see the parameter description in
973973
the `build_grid` function.
974974
975975
TODO - what if there is an internal fixed node? Should we truncate
@@ -994,12 +994,16 @@ def shape_scale_from_mean_var(mean, var):
994994

995995
else:
996996
raise ValueError("prior distribution must be lognorm or gamma")
997-
if node_var_override is None:
998-
node_var_override = {}
997+
samples = ts.samples()
998+
if nonfixed_sample_var is None:
999+
nonfixed_sample_var = {}
1000+
for u in nonfixed_sample_var.keys():
1001+
if u not in samples:
1002+
raise ValueError(f"Node {u} in 'nonfixed_sample_var' is not a sample")
9991003
datable_nodes = np.ones(ts.num_nodes, dtype=bool)
1000-
datable_nodes[ts.samples()] = False
1001-
# Mark all nodes in node_var_override as datable
1002-
datable_nodes[list(node_var_override.keys())] = True
1004+
datable_nodes[samples] = False
1005+
# Mark all nodes in nonfixed_sample_var as datable
1006+
datable_nodes[list(nonfixed_sample_var.keys())] = True
10031007
datable_nodes = np.where(datable_nodes)[0]
10041008

10051009
prior_times = base.NodeGridValues(
@@ -1012,10 +1016,10 @@ def shape_scale_from_mean_var(mean, var):
10121016
for node in tqdm(
10131017
datable_nodes, desc="Assign Prior to Each Node", disable=not progress
10141018
):
1015-
if node in node_var_override:
1019+
if node in nonfixed_sample_var:
10161020
shape, scale = shape_scale_from_mean_var(
10171021
mean=ts.node(node).time,
1018-
var=node_var_override[node],
1022+
var=nonfixed_sample_var[node],
10191023
)
10201024
else:
10211025
shape = shape_param[node]
@@ -1098,7 +1102,7 @@ def build_grid(
10981102
prior_distribution="lognorm",
10991103
allow_historical_samples=None,
11001104
truncate_priors=None,
1101-
node_var_override=None,
1105+
nonfixed_sample_var=None,
11021106
eps=1e-6,
11031107
# Parameters below undocumented
11041108
progress=False,
@@ -1127,20 +1131,21 @@ def build_grid(
11271131
gamma distribution (slightly faster, but a poorer fit for recent nodes).
11281132
Default: "lognorm"
11291133
:param bool allow_historical_samples: should we allow historical samples (i.e. at
1130-
times > 0. This invalidates the assumptions of the conditional coalescent, but
1134+
times > 0). This invalidates the assumptions of the conditional coalescent, but
11311135
may be acceptable if the historical samples are recent or if there are many
1132-
contemporaneous samples. Default: `False`
1136+
contemporaneous samples. Default: ``False``
11331137
:param bool truncate_priors: If there are historical samples, should we truncate the
1134-
priors of their direct ancestor nodes so that the probability of being younger
1135-
than the oldest descendant sample is zero. If the tree sequence is trustworthy
1136-
this should give better restults. Default: `True`
1137-
:param dict node_var_override: is a dict mapping node IDs to a variance value.
1138-
Any nodes listed here will be treated as non-fixed nodes whose prior is not
1139-
calculated from the conditional coalescent but instead are allocated a prior
1138+
priors of all nodes which are their ancestors so that the probability of being
1139+
younger than the oldest descendant sample is zero. As long as historical
1140+
samples do not have ancestors that have been misassigned in the tree sequence
1141+
topology, this should give better results. Default: ``True``
1142+
:param dict nonfixed_sample_var: is a dict mapping sample node IDs to a variance
1143+
value. Any nodes listed here will be treated as non-fixed nodes whose prior is
1144+
not calculated from the conditional coalescent but instead are allocated a prior
11401145
whose mean is the node time in the tree sequence and whose variance is the
11411146
value in this dictionary. This allows sample nodes to be treated as nonfixed
11421147
nodes, and therefore dated. If ``None`` (default) then all sample nodes are
1143-
treated as occurring ata fixed time (as if this were an empty dict).
1148+
treated as occurring at a fixed time (as if this were an empty dict).
11441149
:param float eps: Specify minimum distance separating points in the time grid. Also
11451150
specifies the error factor in time difference calculations. Default: 1e-6
11461151
:return: A prior object to pass to tsdate.date() containing prior values for
@@ -1201,7 +1206,7 @@ def build_grid(
12011206
tree_sequence,
12021207
Ne,
12031208
prior_distr=prior_distribution,
1204-
node_var_override=node_var_override,
1209+
nonfixed_sample_var=nonfixed_sample_var,
12051210
progress=progress,
12061211
)
12071212
if np.any(tree_sequence.nodes_time[tree_sequence.samples()] > 0):

0 commit comments

Comments
 (0)