4949
5050
5151AutoIdentity = tfp .experimental .auto_composite_tensor (
52- tf .linalg .LinearOperatorIdentity , omit_kwargs = ('name' ,))
52+ tf .linalg .LinearOperatorIdentity , non_identifying_kwargs = ('name' ,))
5353AutoDiag = tfp .experimental .auto_composite_tensor (
54- tf .linalg .LinearOperatorDiag , omit_kwargs = ('name' ,))
54+ tf .linalg .LinearOperatorDiag , non_identifying_kwargs = ('name' ,))
5555AutoBlockDiag = tfp .experimental .auto_composite_tensor (
56- tf .linalg .LinearOperatorBlockDiag , omit_kwargs = ('name' ,))
56+ tf .linalg .LinearOperatorBlockDiag , non_identifying_kwargs = ('name' ,))
5757AutoTriL = tfp .experimental .auto_composite_tensor (
58- tf .linalg .LinearOperatorLowerTriangular , omit_kwargs = ('name' ,))
58+ tf .linalg .LinearOperatorLowerTriangular , non_identifying_kwargs = ('name' ,))
5959
6060AutoNormal = tfp .experimental .auto_composite_tensor (
61- tfd .Normal , omit_kwargs = ('name' ,))
61+ tfd .Normal , non_identifying_kwargs = ('name' ,))
6262AutoIndependent = tfp .experimental .auto_composite_tensor (
63- tfd .Independent , omit_kwargs = ('name' ,))
63+ tfd .Independent , non_identifying_kwargs = ('name' ,))
6464AutoReshape = tfp .experimental .auto_composite_tensor (
65- tfb .Reshape , omit_kwargs = ('name' ,))
65+ tfb .Reshape , non_identifying_kwargs = ('name' ,))
6666
6767
6868class Model (tf .Module ):
@@ -105,7 +105,7 @@ def tearDownModule():
105105class AutoCompositeTensorTest (test_util .TestCase ):
106106
107107 def test_example (self ):
108- @tfp .experimental .auto_composite_tensor (omit_kwargs = ('name' ,))
108+ @tfp .experimental .auto_composite_tensor (non_identifying_kwargs = ('name' ,))
109109 class Adder (object ):
110110
111111 def __init__ (self , x , y , name = None ):
@@ -185,7 +185,7 @@ def test_preconditioner(self):
185185 tfed = tfp .experimental .distributions
186186 auto_ct_mvn_prec_linop = tfp .experimental .auto_composite_tensor (
187187 tfed .MultivariateNormalPrecisionFactorLinearOperator ,
188- omit_kwargs = ('name' ,))
188+ non_identifying_kwargs = ('name' ,))
189189 tril = AutoTriL (** cov_linop .cholesky ().parameters )
190190 momentum_distribution = auto_ct_mvn_prec_linop (precision_factor = tril )
191191 def body (d ):
@@ -408,15 +408,26 @@ def __init__(self):
408408 d_ct = AutoStandardNormal ()
409409 self .assertLen (tf .nest .flatten (d_ct , expand_composites = True ), 0 )
410410
411+ def test_names_preserved_through_flatten (self ):
412+
413+ dist = AutoNormal (0. , scale = 3. , name = 'ScaleThreeNormal' )
414+ flat = tf .nest .flatten (dist , expand_composites = True )
415+ unflat = tf .nest .pack_sequence_as (dist , flat , expand_composites = True )
416+ unflat_name = ('ScaleThreeNormal' if tf .executing_eagerly ()
417+ else 'ScaleThreeNormal_1' )
418+ self .assertEqual (unflat .name , unflat_name )
419+
411420
412421class _TestTypeSpec (auto_composite_tensor ._AutoCompositeTensorTypeSpec ):
413422
414423 def __init__ (self , param_specs , non_tensor_params = None , omit_kwargs = (),
415- prefer_static_value = (), callable_params = None ):
424+ prefer_static_value = (), non_identifying_kwargs = (),
425+ callable_params = None ):
416426 non_tensor_params = {} if non_tensor_params is None else non_tensor_params
417427 super (_TestTypeSpec , self ).__init__ (
418428 param_specs , non_tensor_params = non_tensor_params ,
419429 omit_kwargs = omit_kwargs , prefer_static_value = prefer_static_value ,
430+ non_identifying_kwargs = non_identifying_kwargs ,
420431 callable_params = callable_params )
421432
422433 @property
@@ -452,7 +463,16 @@ class AutoCompositeTensorTypeSpecTest(test_util.TestCase):
452463 'b' : tfb .Scale (3. )._type_spec },
453464 omit_kwargs = ('name' , 'foo' ),
454465 prefer_static_value = ('a' ,),
455- callable_params = {'f' : tf .math .exp }))
466+ callable_params = {'f' : tf .math .exp })),
467+ ('DifferentNonIdentifyingKwargsValues' ,
468+ _TestTypeSpec (
469+ param_specs = {'x' : tf .TensorSpec ([], tf .float64 )},
470+ non_tensor_params = {'name' : 'MyAutoCT' },
471+ non_identifying_kwargs = ('name' )),
472+ _TestTypeSpec (
473+ param_specs = {'x' : tf .TensorSpec ([], tf .float64 )},
474+ non_tensor_params = {'name' : 'OtherAutoCT' },
475+ non_identifying_kwargs = ('name' ))),
456476 )
457477 def testEquality (self , v1 , v2 ):
458478 # pylint: disable=g-generic-assert
@@ -480,7 +500,15 @@ def testEquality(self, v1, v2):
480500 _TestTypeSpec (
481501 param_specs = {'a' : tf .TensorSpec ([3 , None ], tf .float32 )},
482502 omit_kwargs = ('name' , 'foo' ),
483- callable_params = {'f' : tf .math .sigmoid }))
503+ callable_params = {'f' : tf .math .sigmoid })),
504+ ('DifferentMetadata' ,
505+ _TestTypeSpec (
506+ param_specs = {'a' : tf .TensorSpec ([3 , 2 ], tf .float32 )},
507+ non_tensor_params = {'validate_args' : True },
508+ non_identifying_kwargs = ('name' ,)),
509+ _TestTypeSpec (
510+ param_specs = {'a' : tf .TensorSpec ([3 , None ], tf .float32 )},
511+ non_tensor_params = {'validate_args' : True })),
484512 )
485513 def testInequality (self , v1 , v2 ):
486514 # pylint: disable=g-generic-assert
@@ -512,7 +540,16 @@ def testInequality(self, v1, v2):
512540 param_specs = {'a' : tf .TensorSpec ([3 , None ], tf .float32 ),
513541 'b' : tfb .Scale (3. )._type_spec },
514542 omit_kwargs = ('name' , 'foo' ),
515- callable_params = {'f' : tf .math .exp }))
543+ callable_params = {'f' : tf .math .exp })),
544+ ('DifferentNonIdentifyingKwargsValues' ,
545+ _TestTypeSpec (
546+ param_specs = {'x' : tf .TensorSpec (None , tf .float64 )},
547+ non_tensor_params = {'name' : 'MyAutoCT' },
548+ non_identifying_kwargs = ('name' )),
549+ _TestTypeSpec (
550+ param_specs = {'x' : tf .TensorSpec ([], tf .float64 )},
551+ non_tensor_params = {'name' : 'OtherAutoCT' },
552+ non_identifying_kwargs = ('name' ))),
516553 )
517554 def testIsCompatibleWith (self , v1 , v2 ):
518555 self .assertTrue (v1 .is_compatible_with (v2 ))
@@ -625,7 +662,7 @@ def testMostSpecificCompatibleTypeException(self, v1, v2):
625662 ('WithoutCallable' ,
626663 _TestTypeSpec (
627664 param_specs = {'a' : tf .TensorSpec ([4 , 2 ], tf .float32 )},
628- omit_kwargs = ('name' ,))),
665+ omit_kwargs = ('parameters' ,), non_identifying_kwargs = ( ' name' ,))),
629666 ('WithCallable' ,
630667 _TestTypeSpec (
631668 param_specs = {'a' : tf .TensorSpec (None , tf .float32 ),
@@ -636,7 +673,8 @@ def testMostSpecificCompatibleTypeException(self, v1, v2):
636673 def testRepr (self , spec ):
637674 spec_data = (auto_composite_tensor ._AUTO_COMPOSITE_TENSOR_VERSION ,
638675 spec ._param_specs , spec ._non_tensor_params , spec ._omit_kwargs ,
639- spec ._prefer_static_value , spec ._callable_params )
676+ spec ._prefer_static_value , spec ._non_identifying_kwargs ,
677+ spec ._callable_params )
640678 self .assertEqual (repr (spec ), f'_TestTypeSpec{ spec_data } ' )
641679
642680if __name__ == '__main__' :
0 commit comments