We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7951c1c commit 91ae835Copy full SHA for 91ae835
tensorflow_probability/python/distributions/inflated.py
@@ -148,7 +148,8 @@ def __init__(self,
148
probs_or_logits,
149
# pylint: disable=g-long-lambda
150
lambda _: tf.broadcast_to(self._inflated_loc,
151
- ps.shape(probs_or_logits))),
+ ps.shape(probs_or_logits)),
152
+ shape=probs_or_logits.shape),
153
atol=self._inflated_loc_atol,
154
rtol=self._inflated_loc_rtol,
155
validate_args=validate_args,
0 commit comments