Skip to content

Commit eef71d9

Browse files
committed
WIP
1 parent e11cf83 commit eef71d9

File tree

2 files changed

+31
-25
lines changed

2 files changed

+31
-25
lines changed

tsdate/core.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -258,15 +258,28 @@ def get_mut_lik_fixed_node(self, edge):
258258

259259
mutations_on_edge = self.mut_edges[edge.id]
260260
child_time = self.ts.node(edge.child).time
261-
#assert child_time == 0
262-
# Temporary hack - we should really take a more precise likelihood
263-
return self._lik(
264-
mutations_on_edge,
265-
edge.span,
266-
self.timediff,
267-
self.mut_rate,
268-
normalize=self.normalize,
269-
)
261+
if child_time == 0:
262+
return self._lik(
263+
mutations_on_edge,
264+
edge.span,
265+
self.timediff,
266+
self.mut_rate,
267+
normalize=self.normalize,
268+
)
269+
else:
270+
timediff = self.timepoints - child_time + 1e-8
271+
# Temporary hack - we should really take a more precise likelihood
272+
likelihood = self._lik(
273+
mutations_on_edge,
274+
edge.span,
275+
timediff,
276+
self.mut_rate,
277+
normalize=False,
278+
)
279+
likelihood[timediff < 0] = 0
280+
281+
return likelihood
282+
270283

271284
def get_mut_lik_lower_tri(self, edge):
272285
"""
@@ -677,10 +690,13 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
677690
)
678691
edge_lik = self.lik.get_inside(daughter_val, edge)
679692
val = self.lik.combine(val, edge_lik)
693+
if np.all(val == 0):
694+
raise ValueError
680695
if cache_inside:
681696
g_i[edge.id] = edge_lik
682697
norm[parent] = np.max(val) if normalize else 1
683698
inside[parent] = self.lik.reduce(val, norm[parent])
699+
684700
if cache_inside:
685701
self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None])
686702
# Keep the results in this object
@@ -732,10 +748,10 @@ def outside_pass(
732748
if ignore_oldest_root:
733749
if edge.parent == self.ts.num_nodes - 1:
734750
continue
735-
#if edge.parent in self.fixednodes:
736-
# raise RuntimeError(
737-
# "Fixed nodes cannot currently be parents in the TS"
738-
# )
751+
if edge.parent in self.fixednodes:
752+
raise RuntimeError(
753+
"Fixed nodes cannot currently be parents in the TS"
754+
)
739755
# Geometric scaling works exactly for all nodes fixed in graph
740756
# but is an approximation when times are unknown.
741757
spanfrac = edge.span / self.spans[child]
@@ -1065,10 +1081,6 @@ def get_dates(
10651081
10661082
:return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
10671083
"""
1068-
# Stuff yet to be implemented. These can be deleted once fixed
1069-
#for sample in tree_sequence.samples():
1070-
# if tree_sequence.node(sample).time != 0:
1071-
# raise NotImplementedError("Samples must all be at time 0")
10721084
fixed_nodes = set(tree_sequence.samples())
10731085

10741086
# Default to not creating approximate priors unless ts has > 1000 samples

tsdate/prior.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,8 @@ def build_grid(
10311031
inference and a discretised time grid
10321032
:rtype: base.NodeGridValues Object
10331033
"""
1034+
#tree_sequence = tree_sequence.simplify(tree_sequence.samples())
1035+
10341036
if Ne <= 0:
10351037
raise ValueError("Parameter 'Ne' must be greater than 0")
10361038
if approximate_priors:
@@ -1042,12 +1044,6 @@ def build_grid(
10421044
"Can't set approx_prior_size if approximate_prior is False"
10431045
)
10441046

1045-
#contmpr_ts, node_map = util.reduce_to_contemporaneous(tree_sequence)
1046-
#if contmpr_ts.num_nodes != tree_sequence.num_nodes:
1047-
# raise ValueError(
1048-
# "Passed tree sequence is not simplified and/or contains "
1049-
# "noncontemporaneous samples"
1050-
# )
10511047
span_data = SpansBySamples(tree_sequence, progress=progress)
10521048

10531049
base_priors = ConditionalCoalescentTimes(
@@ -1079,8 +1075,6 @@ def build_grid(
10791075
raise ValueError("time_slices must be an integer or a numpy array of floats")
10801076

10811077
prior_params = base_priors.get_mixture_prior_params(span_data)
1082-
# Map the nodes in the prior params back to the node ids in the original ts
1083-
#prior_params = prior_params_contmpr[node_map, :]
10841078
# Set all fixed nodes (i.e. samples) to have 0 variance
10851079
priors = fill_priors(
10861080
prior_params,

0 commit comments

Comments
 (0)