Skip to content

Commit 389b346

Browse files
Use prefer_static.shape over tf.shape in inflated.py's broadcasting of
inflated_loc, to be support jax.jit. PiperOrigin-RevId: 475822622
1 parent 890f555 commit 389b346

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,7 @@ multi_substrate_py_library(
10141014
"//tensorflow_probability/python/internal:auto_composite_tensor",
10151015
"//tensorflow_probability/python/internal:dtype_util",
10161016
"//tensorflow_probability/python/internal:parameter_properties",
1017+
"//tensorflow_probability/python/internal:prefer_static",
10171018
"//tensorflow_probability/python/internal:samplers",
10181019
"//tensorflow_probability/python/internal:tensor_util",
10191020
"//tensorflow_probability/python/util:deferred_tensor",

tensorflow_probability/python/distributions/inflated.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from tensorflow_probability.python.internal import auto_composite_tensor
2929
from tensorflow_probability.python.internal import dtype_util
3030
from tensorflow_probability.python.internal import parameter_properties
31+
from tensorflow_probability.python.internal import prefer_static as ps
3132
from tensorflow_probability.python.internal import samplers
3233
from tensorflow_probability.python.internal import tensor_util
3334
from tensorflow_probability.python.util.deferred_tensor import DeferredTensor
@@ -147,9 +148,7 @@ def __init__(self,
147148
probs_or_logits,
148149
# pylint: disable=g-long-lambda
149150
lambda _: tf.broadcast_to(self._inflated_loc,
150-
tf.shape(probs_or_logits))
151-
# pylint: enable=g-long-lambda
152-
),
151+
ps.shape(probs_or_logits))),
153152
atol=self._inflated_loc_atol,
154153
rtol=self._inflated_loc_rtol,
155154
validate_args=validate_args,

0 commit comments

Comments
 (0)