Skip to content

Commit 85de528

Browse files
Googlertensorflower-gardener
authored andcommitted
Change validate_event_shape's default value to False.
Improves the bijector's compatibility with `vectorized_map`. PiperOrigin-RevId: 377204164
1 parent 40d9757 commit 85de528

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

tensorflow_probability/python/bijectors/chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class Chain(composition.Composition):
8181
def __init__(self,
8282
bijectors=None,
8383
validate_args=False,
84-
validate_event_size=True,
84+
validate_event_size=False,
8585
parameters=None,
8686
name=None):
8787
"""Instantiates `Chain` bijector.

tensorflow_probability/python/bijectors/chain_test.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,17 +378,23 @@ def testDofChangeError(self):
378378
smc = tfb.SoftmaxCentered()
379379

380380
# Increase in event-size is the last step. No problems here.
381-
safe_bij = tfb.Chain([smc, exp], validate_args=True)
381+
safe_bij = tfb.Chain([smc, exp],
382+
validate_args=True,
383+
validate_event_size=True)
382384
self.evaluate(safe_bij.forward_log_det_jacobian([1., 2., 3.], 1))
383385

384386
# Increase in event-size before Exp.
385-
raise_bij = tfb.Chain([exp, smc], validate_args=True)
387+
raise_bij = tfb.Chain([exp, smc],
388+
validate_args=True,
389+
validate_event_size=True)
386390
with self.assertRaisesRegex((ValueError, tf.errors.InvalidArgumentError),
387391
r".+degrees of freedom.+"):
388392
self.evaluate(raise_bij.forward_log_det_jacobian([1., 2., 3.], 1))
389393

390394
# When validate_args is False, warns instead of raising.
391-
warn_bij = tfb.Chain([exp, smc], validate_args=False)
395+
warn_bij = tfb.Chain([exp, smc],
396+
validate_args=False,
397+
validate_event_size=True)
392398
with mock.patch.object(tf, "print", return_value=tf.no_op()) as mock_print:
393399
self.evaluate(warn_bij.forward_log_det_jacobian([1., 2., 3.], 1))
394400
print_args, _ = mock_print.call_args

0 commit comments

Comments
 (0)