@@ -151,7 +151,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
151
151
"""
152
152
ll = scipy .stats .poisson .pmf (muts , dt * mutation_rate * span )
153
153
if normalize :
154
- return ll / np .max (ll )
154
+ return ll / np .nanmax (ll )
155
155
else :
156
156
return ll
157
157
@@ -274,12 +274,12 @@ def get_mut_lik_fixed_node(self, edge):
274
274
edge .span ,
275
275
timediff ,
276
276
self .mut_rate ,
277
- normalize = False ,
277
+ normalize = self . normalize ,
278
278
)
279
+ # Prevent child from being older than parent
279
280
likelihood [timediff < 0 ] = 0
280
-
281
- return likelihood
282
281
282
+ return likelihood
283
283
284
284
def get_mut_lik_lower_tri (self , edge ):
285
285
"""
@@ -402,7 +402,7 @@ def get_fixed(self, arr, edge):
402
402
return arr * liks
403
403
404
404
def scale_geometric (self , fraction , value ):
405
- return value ** fraction
405
+ return value ** fraction
406
406
407
407
408
408
class LogLikelihoods (Likelihoods ):
@@ -442,7 +442,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
442
442
"""
443
443
ll = scipy .stats .poisson .logpmf (muts , dt * mutation_rate * span )
444
444
if normalize :
445
- return ll - np .max (ll )
445
+ return ll - np .nanmax (ll )
446
446
else :
447
447
return ll
448
448
@@ -696,7 +696,7 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
696
696
g_i [edge .id ] = edge_lik
697
697
norm [parent ] = np .max (val ) if normalize else 1
698
698
inside [parent ] = self .lik .reduce (val , norm [parent ])
699
-
699
+
700
700
if cache_inside :
701
701
self .g_i = self .lik .reduce (g_i , norm [self .ts .tables .edges .child , None ])
702
702
# Keep the results in this object
0 commit comments