Skip to content

Commit 0867ae3

Browse files
faizan-mtensorflower-gardener
authored andcommitted
Fix Cache Key Mismatch for Function Input Signature
- A TensorSpec with no name is now equivalent to a Tensor with the same dtype and shape - Fixed the shape mismatch issue for Tensors: (11,2) was equal to (1, 12) PiperOrigin-RevId: 390241221
1 parent 956d09f commit 0867ae3

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

tensorflow_probability/python/distributions/poisson_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,12 @@ def testSampleGPU(self):
450450
def testSampleXLA(self):
451451
self.skip_if_no_xla()
452452
if not tf.executing_eagerly(): return # jit_compile is eager-only.
453+
454+
# TODO(b/195975508): Reloading the function to reset the cache.
455+
if not test_util.JAX_MODE:
456+
poisson_lib.random_poisson = tf.function(
457+
poisson_lib.random_poisson._python_function)
458+
453459
log_rates = np.random.rand(4, 3).astype(np.float32)
454460
dist = tfd.Poisson(log_rate=log_rates, validate_args=True)
455461
# Verify the compile succeeds going all the way through the distribution.

0 commit comments

Comments
 (0)