Skip to content

Commit 98cc77e

Browse files
committed
Add natural argument to to_gamma
1 parent 5ef0198 commit 98cc77e

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

tsdate/core.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -542,14 +542,18 @@ def __init__(
542542
self.identity_constant.flags.writeable = False
543543
self.timepoints.flags.writeable = False
544544

545-
def to_gamma(self, edge):
545+
def to_gamma(self, edge, natural=False):
546546
"""
547547
Return the shape and rate parameters of the (gamma) posterior of edge
548-
length, given an improper (constant) prior.
548+
length, given an improper (constant) prior. If ``natural`` is ``True``,
549+
return the natural parameterization instead.
549550
"""
550551
y = self.mut_edges[edge.id]
551552
mu = edge.span * self.mut_rate
552-
return np.array([y + 1, mu])
553+
if natural:
554+
return np.array([y, mu])
555+
else:
556+
return np.array([y + 1, mu])
553557

554558
@staticmethod
555559
def get_mut_edges(ts):
@@ -973,7 +977,7 @@ def __init__(self, *args, **kwargs):
973977
# and can be incorporated into the posterior beforehand
974978
for edge in self.ts.edges():
975979
if edge.child in self.fixednodes:
976-
self.parent_message[edge.id] = self.lik.to_gamma(edge)
980+
self.parent_message[edge.id] = self.lik.to_gamma(edge, natural=False)
977981
self.posterior[edge.parent] = self.lik.combine(
978982
self.posterior[edge.parent], self.parent_message[edge.id]
979983
)
@@ -996,8 +1000,7 @@ def propagate(self, *, edges, progress=None):
9961000
continue
9971001
if edge.parent in self.fixednodes:
9981002
raise ValueError("Internal nodes can not be fixed in EP algorithm")
999-
edge_lik = self.lik.to_gamma(edge)
1000-
edge_lik += np.array([-1.0, 0.0]) # to Poisson, TODO cleanup
1003+
edge_lik = self.lik.to_gamma(edge, natural=True)
10011004
# Get the cavity posteriors: that is, the rest of the approximation
10021005
# without the factor for this edge. This only involves the variational
10031006
# parameters for the parent and child on the edge.
@@ -1181,13 +1184,13 @@ def date(
11811184
:param int num_threads: The number of threads to use. A simpler unthreaded algorithm
11821185
is used unless this is >= 1. Default: None
11831186
:param string method: What estimation method to use: can be
1184-
"EP" (variational approximation, empirically most accurate),
1187+
"variational_gamma" (variational approximation, empirically most accurate),
11851188
"inside_outside" (empirically better, theoretically problematic) or
11861189
"maximization" (worse empirically, especially with gamma approximated priors,
11871190
but theoretically robust). If ``None`` (default) use "inside_outside"
11881191
:param string probability_space: Should the internal algorithm save probabilities in
11891192
"logarithmic" (slower, less liable to to overflow) or "linear" space (fast, may
1190-
overflow). Does not apply to method ``EP``. Default: "logarithmic"
1193+
overflow). Does not apply to method ``variational_gamma``. Default: "logarithmic"
11911194
:param bool ignore_oldest_root: Should the oldest root in the tree sequence be
11921195
ignored in the outside algorithm (if "inside_outside" is used as the method).
11931196
Ignoring outside root provides greater stability when dating tree sequences
@@ -1212,7 +1215,7 @@ def date(
12121215
)
12131216
else:
12141217
population_size = Ne
1215-
if method == "EP":
1218+
if method == "variational_gamma":
12161219
tree_sequence, dates, posteriors, timepoints, eps, nds = variational_dates(
12171220
tree_sequence,
12181221
population_size=population_size,

0 commit comments

Comments
 (0)