Skip to content

Commit 974038d

Browse files
committed
Add unity normalization constant to unfixed leaf nodes
1 parent de1c751 commit 974038d

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

tsdate/core.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -649,9 +649,7 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
649649
)
650650
# It is possible that a simple node is non-fixed, in which case we want to
651651
# provide an inside array that reflects the prior distribution
652-
nonfixed_samples = np.intersect1d(
653-
self.priors.nonfixed_node_ids(), self.ts.samples()
654-
)
652+
nonfixed_samples = np.intersect1d(inside.nonfixed_node_ids(), self.ts.samples())
655653
for u in nonfixed_samples:
656654
# this is in the same probability space as the prior, so we should be
657655
# OK just to copy the prior values straight in (but we should check they
@@ -663,6 +661,8 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
663661
(self.ts.num_edges, self.lik.grid_size), self.lik.identity_constant
664662
)
665663
norm = np.full(self.ts.num_nodes, np.nan)
664+
to_visit = np.zeros(self.ts.num_nodes, dtype=bool)
665+
to_visit[inside.nonfixed_node_ids()] = True
666666
# Iterate through the nodes via groupby on parent node
667667
for parent, edges in tqdm(
668668
self.edges_by_parent_asc(),
@@ -707,6 +707,12 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
707707
g_i[edge.id] = edge_lik
708708
norm[parent] = np.max(val) if normalize else 1
709709
inside[parent] = self.lik.reduce(val, norm[parent])
710+
to_visit[parent] = False
711+
712+
# There may be nodes that are not parents but are also not fixed (e.g.
713+
# undated sample nodes). These need an identity normalization constant
714+
for unfixed_unvisited in np.where(to_visit)[0]:
715+
norm[unfixed_unvisited] = 1
710716

711717
if cache_inside:
712718
self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None])

0 commit comments

Comments
 (0)