@@ -258,15 +258,28 @@ def get_mut_lik_fixed_node(self, edge):
258
258
259
259
mutations_on_edge = self .mut_edges [edge .id ]
260
260
child_time = self .ts .node (edge .child ).time
261
- #assert child_time == 0
262
- # Temporary hack - we should really take a more precise likelihood
263
- return self ._lik (
264
- mutations_on_edge ,
265
- edge .span ,
266
- self .timediff ,
267
- self .mut_rate ,
268
- normalize = self .normalize ,
269
- )
261
+ if child_time == 0 :
262
+ return self ._lik (
263
+ mutations_on_edge ,
264
+ edge .span ,
265
+ self .timediff ,
266
+ self .mut_rate ,
267
+ normalize = self .normalize ,
268
+ )
269
+ else :
270
+ timediff = self .timepoints - child_time + 1e-8
271
+ # Temporary hack - we should really take a more precise likelihood
272
+ likelihood = self ._lik (
273
+ mutations_on_edge ,
274
+ edge .span ,
275
+ timediff ,
276
+ self .mut_rate ,
277
+ normalize = False ,
278
+ )
279
+ likelihood [timediff < 0 ] = 0
280
+
281
+ return likelihood
282
+
270
283
271
284
def get_mut_lik_lower_tri (self , edge ):
272
285
"""
@@ -677,10 +690,13 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
677
690
)
678
691
edge_lik = self .lik .get_inside (daughter_val , edge )
679
692
val = self .lik .combine (val , edge_lik )
693
+ if np .all (val == 0 ):
694
+ raise ValueError
680
695
if cache_inside :
681
696
g_i [edge .id ] = edge_lik
682
697
norm [parent ] = np .max (val ) if normalize else 1
683
698
inside [parent ] = self .lik .reduce (val , norm [parent ])
699
+
684
700
if cache_inside :
685
701
self .g_i = self .lik .reduce (g_i , norm [self .ts .tables .edges .child , None ])
686
702
# Keep the results in this object
@@ -732,10 +748,10 @@ def outside_pass(
732
748
if ignore_oldest_root :
733
749
if edge .parent == self .ts .num_nodes - 1 :
734
750
continue
735
- # if edge.parent in self.fixednodes:
736
- # raise RuntimeError(
737
- # "Fixed nodes cannot currently be parents in the TS"
738
- # )
751
+ if edge .parent in self .fixednodes :
752
+ raise RuntimeError (
753
+ "Fixed nodes cannot currently be parents in the TS"
754
+ )
739
755
# Geometric scaling works exactly for all nodes fixed in graph
740
756
# but is an approximation when times are unknown.
741
757
spanfrac = edge .span / self .spans [child ]
@@ -1065,10 +1081,6 @@ def get_dates(
1065
1081
1066
1082
:return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
1067
1083
"""
1068
- # Stuff yet to be implemented. These can be deleted once fixed
1069
- #for sample in tree_sequence.samples():
1070
- # if tree_sequence.node(sample).time != 0:
1071
- # raise NotImplementedError("Samples must all be at time 0")
1072
1084
fixed_nodes = set (tree_sequence .samples ())
1073
1085
1074
1086
# Default to not creating approximate priors unless ts has > 1000 samples
0 commit comments