Skip to content

Commit b5e9bba

Browse files
committed
fix to normalization
1 parent 34c383c commit b5e9bba

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tsdate/core.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
151151
"""
152152
ll = scipy.stats.poisson.pmf(muts, dt * mutation_rate * span)
153153
if normalize:
154-
return ll / np.max(ll)
154+
return ll / np.nanmax(ll)
155155
else:
156156
return ll
157157

@@ -274,12 +274,12 @@ def get_mut_lik_fixed_node(self, edge):
274274
edge.span,
275275
timediff,
276276
self.mut_rate,
277-
normalize=False,
277+
normalize=self.normalize,
278278
)
279+
# Prevent child from being older than parent
279280
likelihood[timediff < 0] = 0
280-
281-
return likelihood
282281

282+
return likelihood
283283

284284
def get_mut_lik_lower_tri(self, edge):
285285
"""
@@ -402,7 +402,7 @@ def get_fixed(self, arr, edge):
402402
return arr * liks
403403

404404
def scale_geometric(self, fraction, value):
405-
return value ** fraction
405+
return value**fraction
406406

407407

408408
class LogLikelihoods(Likelihoods):
@@ -442,7 +442,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
442442
"""
443443
ll = scipy.stats.poisson.logpmf(muts, dt * mutation_rate * span)
444444
if normalize:
445-
return ll - np.max(ll)
445+
return ll - np.nanmax(ll)
446446
else:
447447
return ll
448448

@@ -696,7 +696,7 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
696696
g_i[edge.id] = edge_lik
697697
norm[parent] = np.max(val) if normalize else 1
698698
inside[parent] = self.lik.reduce(val, norm[parent])
699-
699+
700700
if cache_inside:
701701
self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None])
702702
# Keep the results in this object

0 commit comments

Comments
 (0)