@@ -378,17 +378,23 @@ def testDofChangeError(self):
378
378
smc = tfb .SoftmaxCentered ()
379
379
380
380
# Increase in event-size is the last step. No problems here.
381
- safe_bij = tfb .Chain ([smc , exp ], validate_args = True )
381
+ safe_bij = tfb .Chain ([smc , exp ],
382
+ validate_args = True ,
383
+ validate_event_size = True )
382
384
self .evaluate (safe_bij .forward_log_det_jacobian ([1. , 2. , 3. ], 1 ))
383
385
384
386
# Increase in event-size before Exp.
385
- raise_bij = tfb .Chain ([exp , smc ], validate_args = True )
387
+ raise_bij = tfb .Chain ([exp , smc ],
388
+ validate_args = True ,
389
+ validate_event_size = True )
386
390
with self .assertRaisesRegex ((ValueError , tf .errors .InvalidArgumentError ),
387
391
r".+degrees of freedom.+" ):
388
392
self .evaluate (raise_bij .forward_log_det_jacobian ([1. , 2. , 3. ], 1 ))
389
393
390
394
# When validate_args is False, warns instead of raising.
391
- warn_bij = tfb .Chain ([exp , smc ], validate_args = False )
395
+ warn_bij = tfb .Chain ([exp , smc ],
396
+ validate_args = False ,
397
+ validate_event_size = True )
392
398
with mock .patch .object (tf , "print" , return_value = tf .no_op ()) as mock_print :
393
399
self .evaluate (warn_bij .forward_log_det_jacobian ([1. , 2. , 3. ], 1 ))
394
400
print_args , _ = mock_print .call_args
0 commit comments