@@ -280,6 +280,24 @@ def _get_parameterization(self):
280280 return id (self )
281281
282282
283+ class UnspecifiedParameters (tfb .Bijector ):
284+ """A bijector that fails to pass `parameters` to the base class."""
285+
286+ def __init__ (self , loc ):
287+ self ._loc = loc
288+ super (UnspecifiedParameters , self ).__init__ (
289+ validate_args = False ,
290+ is_constant_jacobian = True ,
291+ forward_min_event_ndims = 0 ,
292+ name = 'unspecified_parameters' )
293+
294+ def _forward (self , x ):
295+ return x + self ._loc
296+
297+ def _forward_log_det_jacobian (self , x ):
298+ return tf .constant (0. , x .dtype )
299+
300+
283301@test_util .test_all_tf_execution_regimes
284302class BijectorTestEventNdims (test_util .TestCase ):
285303
@@ -440,6 +458,18 @@ def testUniqueCacheKey(self):
440458 self .assertLen (bijector_1 ._cache .weak_keys (direction = 'forward' ), 1 )
441459 self .assertLen (bijector_2 ._cache .weak_keys (direction = 'forward' ), 1 )
442460
461+ def testBijectorsWithUnspecifiedParametersDoNotShareCache (self ):
462+ bijector_1 = UnspecifiedParameters (loc = tf .constant (1. , dtype = tf .float32 ))
463+ bijector_2 = UnspecifiedParameters (loc = tf .constant (2. , dtype = tf .float32 ))
464+
465+ x = tf .constant (3. , dtype = tf .float32 )
466+ y_1 = bijector_1 .forward (x )
467+ y_2 = bijector_2 .forward (x )
468+
469+ self .assertIsNot (y_1 , y_2 )
470+ self .assertLen (bijector_1 ._cache .weak_keys (direction = 'forward' ), 1 )
471+ self .assertLen (bijector_2 ._cache .weak_keys (direction = 'forward' ), 1 )
472+
443473 def testInstanceCache (self ):
444474 instance_cache_bijector = tfb .Exp ()
445475 instance_cache_bijector ._cache = cache_util .BijectorCache (
0 commit comments