Skip to content

Commit 5047529

Browse files
ColCarrolltensorflower-gardener
authored andcommitted
Disable test_gradient_with_additional_parameters for JAX backend.
PiperOrigin-RevId: 693475540
1 parent b0a692b commit 5047529

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

tensorflow_probability/python/distributions/batch_broadcast_test.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,23 @@
4040
from tensorflow_probability.python.internal import test_util
4141
from tensorflow_probability.python.random import random_ops
4242

43+
_DIFFERENT_HYPOTHESIS_KWARGS = {}
44+
45+
# This check is done on recent versions of hypothesis, but not all,
46+
# as of November 2024.
47+
if hasattr(hp.HealthCheck, 'differing_executors'):
48+
_DIFFERENT_HYPOTHESIS_KWARGS['suppress_health_check'] = [
49+
hp.HealthCheck.differing_executors
50+
]
51+
4352

4453
@test_util.test_all_tf_execution_regimes
4554
class _BatchBroadcastTest(object):
4655

4756
@hp.given(hps.data())
48-
@tfp_hps.tfp_hp_settings(default_max_examples=5)
57+
@tfp_hps.tfp_hp_settings(
58+
default_max_examples=5,
59+
**_DIFFERENT_HYPOTHESIS_KWARGS)
4960
def test_shapes(self, data):
5061
batch_shape = data.draw(tfp_hps.shapes())
5162
bcast_arg, dist_batch_shp = data.draw(
@@ -63,7 +74,9 @@ def test_shapes(self, data):
6374
dist.event_shape_tensor())
6475

6576
@hp.given(hps.data())
66-
@tfp_hps.tfp_hp_settings(default_max_examples=5)
77+
@tfp_hps.tfp_hp_settings(
78+
default_max_examples=5,
79+
**_DIFFERENT_HYPOTHESIS_KWARGS)
6780
def test_sample(self, data):
6881
batch_shape = data.draw(tfp_hps.shapes())
6982
bcast_arg, dist_batch_shp = data.draw(
@@ -109,7 +122,9 @@ def test_sample(self, data):
109122
self.assertAllClose(lp, dist.log_prob(sample2))
110123

111124
@hp.given(hps.data())
112-
@tfp_hps.tfp_hp_settings(default_max_examples=5)
125+
@tfp_hps.tfp_hp_settings(
126+
default_max_examples=5,
127+
**_DIFFERENT_HYPOTHESIS_KWARGS)
113128
def test_log_prob(self, data):
114129
batch_shape = data.draw(tfp_hps.shapes())
115130
bcast_arg, dist_batch_shp = data.draw(
@@ -235,7 +250,9 @@ def test_docstring_example(self):
235250
self.evaluate(lp)
236251

237252
@hp.given(hps.data())
238-
@tfp_hps.tfp_hp_settings(default_max_examples=5)
253+
@tfp_hps.tfp_hp_settings(
254+
default_max_examples=5,
255+
**_DIFFERENT_HYPOTHESIS_KWARGS)
239256
def test_default_bijector(self, data):
240257
batch_shape = data.draw(tfp_hps.shapes())
241258
bcast_arg, dist_batch_shp = data.draw(

tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def ildj_fn(y):
102102
self.assertAllClose(ildj, ildj_true, atol=1e-4)
103103
self.assertAllClose(ildj_grad, ildj_grad_true, rtol=1e-4)
104104

105+
@test_util.disable_test_for_backend(
106+
disable_jax=True, reason='Tracer leak from additional parameters.')
105107
@test_util.numpy_disable_gradient_test
106108
@parameterized.named_parameters(
107109
{

0 commit comments

Comments
 (0)