Skip to content

Commit 91ae835

Browse files
Let DeferredTensor know the static shape of the broadcasted inflated_loc.
PiperOrigin-RevId: 475879856
1 parent 7951c1c commit 91ae835

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tensorflow_probability/python/distributions/inflated.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ def __init__(self,
148148
probs_or_logits,
149149
# pylint: disable=g-long-lambda
150150
lambda _: tf.broadcast_to(self._inflated_loc,
151-
ps.shape(probs_or_logits))),
151+
ps.shape(probs_or_logits)),
152+
shape=probs_or_logits.shape),
152153
atol=self._inflated_loc_atol,
153154
rtol=self._inflated_loc_rtol,
154155
validate_args=validate_args,

0 commit comments

Comments
 (0)