Skip to content

Commit c1c3861

Browse files
committed
Remove sum_to_unity
1 parent 6b523bd commit c1c3861

File tree

2 files changed

+6
-20
lines changed

2 files changed

+6
-20
lines changed

tsdate/base.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -134,19 +134,6 @@ def force_probability_space(self, probability_space):
134134
else:
135135
logging.warning("Cannot force", *descr)
136136

137-
def sum_to_unity(self, arr):
138-
"""
139-
Return an array for a node in which the (untransformed) values
140-
sum to unity
141-
"""
142-
if self.probability_space == LIN:
143-
return arr / np.sum(arr)
144-
elif self.probability_space == LOG:
145-
with np.errstate(divide="ignore"):
146-
return np.log(np.exp(arr) / np.sum(np.exp(arr)))
147-
else:
148-
raise RuntimeError("Probability space is not", LIN, "or", LOG)
149-
150137
def normalize(self):
151138
"""
152139
normalize grid and fixed data so the max is one
@@ -229,8 +216,7 @@ def fill_fixed(orig, fixed_data):
229216
new_obj.fixed_data = fill_fixed(
230217
self, grid_data if fixed_data is None else fixed_data
231218
)
232-
if probability_space is None:
233-
new_obj.probability_space = self.probability_space
234-
else:
235-
new_obj.probability_space = probability_space
219+
new_obj.probability_space = self.probability_space
220+
if probability_space is not None:
221+
new_obj.force_probability_space(probability_space)
236222
return new_obj

tsdate/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -652,9 +652,9 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
652652
nonfixed_samples = np.intersect1d(inside.nonfixed_node_ids(), self.ts.samples())
653653
for u in nonfixed_samples:
654654
# this is in the same probability space as the prior, so we should be
655-
# OK just to copy the prior values straight in (but we should check they
656-
# are normalised so that they sum to unity)
657-
inside[u][:] = self.priors.sum_to_unity(self.priors[u])
655+
# OK just to copy the prior values straight in. It's unclear to me (Yan)
656+
# how/if they should be normalised, however
657+
inside[u][:] = self.priors[u]
658658

659659
if cache_inside:
660660
g_i = np.full(

0 commit comments

Comments
 (0)