Skip to content

Commit 751aaa3

Browse files
axchtensorflower-gardener
authored andcommitted
Prefer + to tf.add_n in JointDistributionPinned.log_weight, because it broadcasts.
PiperOrigin-RevId: 378889317
1 parent b7fd44b commit 751aaa3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tensorflow_probability/python/experimental/distributions/joint_distribution_pinned.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def log_weight(self, *args, **kwargs): # pylint: disable=g-doc-args
557557
log_weights: log-weight of the given point, i.e. the log pinned evidence.
558558
"""
559559
pin_probs = self.unnormalized_log_prob_parts(*args, **kwargs).pinned
560-
return tf.add_n(
560+
return sum( # Sum uses +, which broadcasts
561561
pin_probs.values() if isinstance(pin_probs, dict) else pin_probs)
562562

563563
@docstring_util.expand_docstring(

0 commit comments

Comments
 (0)