49
49
50
50
51
51
AutoIdentity = tfp .experimental .auto_composite_tensor (
52
- tf .linalg .LinearOperatorIdentity , omit_kwargs = ('name' ,))
52
+ tf .linalg .LinearOperatorIdentity , non_identifying_kwargs = ('name' ,))
53
53
AutoDiag = tfp .experimental .auto_composite_tensor (
54
- tf .linalg .LinearOperatorDiag , omit_kwargs = ('name' ,))
54
+ tf .linalg .LinearOperatorDiag , non_identifying_kwargs = ('name' ,))
55
55
AutoBlockDiag = tfp .experimental .auto_composite_tensor (
56
- tf .linalg .LinearOperatorBlockDiag , omit_kwargs = ('name' ,))
56
+ tf .linalg .LinearOperatorBlockDiag , non_identifying_kwargs = ('name' ,))
57
57
AutoTriL = tfp .experimental .auto_composite_tensor (
58
- tf .linalg .LinearOperatorLowerTriangular , omit_kwargs = ('name' ,))
58
+ tf .linalg .LinearOperatorLowerTriangular , non_identifying_kwargs = ('name' ,))
59
59
60
60
AutoNormal = tfp .experimental .auto_composite_tensor (
61
- tfd .Normal , omit_kwargs = ('name' ,))
61
+ tfd .Normal , non_identifying_kwargs = ('name' ,))
62
62
AutoIndependent = tfp .experimental .auto_composite_tensor (
63
- tfd .Independent , omit_kwargs = ('name' ,))
63
+ tfd .Independent , non_identifying_kwargs = ('name' ,))
64
64
AutoReshape = tfp .experimental .auto_composite_tensor (
65
- tfb .Reshape , omit_kwargs = ('name' ,))
65
+ tfb .Reshape , non_identifying_kwargs = ('name' ,))
66
66
67
67
68
68
class Model (tf .Module ):
@@ -105,7 +105,7 @@ def tearDownModule():
105
105
class AutoCompositeTensorTest (test_util .TestCase ):
106
106
107
107
def test_example (self ):
108
- @tfp .experimental .auto_composite_tensor (omit_kwargs = ('name' ,))
108
+ @tfp .experimental .auto_composite_tensor (non_identifying_kwargs = ('name' ,))
109
109
class Adder (object ):
110
110
111
111
def __init__ (self , x , y , name = None ):
@@ -185,7 +185,7 @@ def test_preconditioner(self):
185
185
tfed = tfp .experimental .distributions
186
186
auto_ct_mvn_prec_linop = tfp .experimental .auto_composite_tensor (
187
187
tfed .MultivariateNormalPrecisionFactorLinearOperator ,
188
- omit_kwargs = ('name' ,))
188
+ non_identifying_kwargs = ('name' ,))
189
189
tril = AutoTriL (** cov_linop .cholesky ().parameters )
190
190
momentum_distribution = auto_ct_mvn_prec_linop (precision_factor = tril )
191
191
def body (d ):
@@ -408,15 +408,26 @@ def __init__(self):
408
408
d_ct = AutoStandardNormal ()
409
409
self .assertLen (tf .nest .flatten (d_ct , expand_composites = True ), 0 )
410
410
411
+ def test_names_preserved_through_flatten (self ):
412
+
413
+ dist = AutoNormal (0. , scale = 3. , name = 'ScaleThreeNormal' )
414
+ flat = tf .nest .flatten (dist , expand_composites = True )
415
+ unflat = tf .nest .pack_sequence_as (dist , flat , expand_composites = True )
416
+ unflat_name = ('ScaleThreeNormal' if tf .executing_eagerly ()
417
+ else 'ScaleThreeNormal_1' )
418
+ self .assertEqual (unflat .name , unflat_name )
419
+
411
420
412
421
class _TestTypeSpec (auto_composite_tensor ._AutoCompositeTensorTypeSpec ):
413
422
414
423
def __init__ (self , param_specs , non_tensor_params = None , omit_kwargs = (),
415
- prefer_static_value = (), callable_params = None ):
424
+ prefer_static_value = (), non_identifying_kwargs = (),
425
+ callable_params = None ):
416
426
non_tensor_params = {} if non_tensor_params is None else non_tensor_params
417
427
super (_TestTypeSpec , self ).__init__ (
418
428
param_specs , non_tensor_params = non_tensor_params ,
419
429
omit_kwargs = omit_kwargs , prefer_static_value = prefer_static_value ,
430
+ non_identifying_kwargs = non_identifying_kwargs ,
420
431
callable_params = callable_params )
421
432
422
433
@property
@@ -452,7 +463,16 @@ class AutoCompositeTensorTypeSpecTest(test_util.TestCase):
452
463
'b' : tfb .Scale (3. )._type_spec },
453
464
omit_kwargs = ('name' , 'foo' ),
454
465
prefer_static_value = ('a' ,),
455
- callable_params = {'f' : tf .math .exp }))
466
+ callable_params = {'f' : tf .math .exp })),
467
+ ('DifferentNonIdentifyingKwargsValues' ,
468
+ _TestTypeSpec (
469
+ param_specs = {'x' : tf .TensorSpec ([], tf .float64 )},
470
+ non_tensor_params = {'name' : 'MyAutoCT' },
471
+ non_identifying_kwargs = ('name' )),
472
+ _TestTypeSpec (
473
+ param_specs = {'x' : tf .TensorSpec ([], tf .float64 )},
474
+ non_tensor_params = {'name' : 'OtherAutoCT' },
475
+ non_identifying_kwargs = ('name' ))),
456
476
)
457
477
def testEquality (self , v1 , v2 ):
458
478
# pylint: disable=g-generic-assert
@@ -480,7 +500,15 @@ def testEquality(self, v1, v2):
480
500
_TestTypeSpec (
481
501
param_specs = {'a' : tf .TensorSpec ([3 , None ], tf .float32 )},
482
502
omit_kwargs = ('name' , 'foo' ),
483
- callable_params = {'f' : tf .math .sigmoid }))
503
+ callable_params = {'f' : tf .math .sigmoid })),
504
+ ('DifferentMetadata' ,
505
+ _TestTypeSpec (
506
+ param_specs = {'a' : tf .TensorSpec ([3 , 2 ], tf .float32 )},
507
+ non_tensor_params = {'validate_args' : True },
508
+ non_identifying_kwargs = ('name' ,)),
509
+ _TestTypeSpec (
510
+ param_specs = {'a' : tf .TensorSpec ([3 , None ], tf .float32 )},
511
+ non_tensor_params = {'validate_args' : True })),
484
512
)
485
513
def testInequality (self , v1 , v2 ):
486
514
# pylint: disable=g-generic-assert
@@ -512,7 +540,16 @@ def testInequality(self, v1, v2):
512
540
param_specs = {'a' : tf .TensorSpec ([3 , None ], tf .float32 ),
513
541
'b' : tfb .Scale (3. )._type_spec },
514
542
omit_kwargs = ('name' , 'foo' ),
515
- callable_params = {'f' : tf .math .exp }))
543
+ callable_params = {'f' : tf .math .exp })),
544
+ ('DifferentNonIdentifyingKwargsValues' ,
545
+ _TestTypeSpec (
546
+ param_specs = {'x' : tf .TensorSpec (None , tf .float64 )},
547
+ non_tensor_params = {'name' : 'MyAutoCT' },
548
+ non_identifying_kwargs = ('name' )),
549
+ _TestTypeSpec (
550
+ param_specs = {'x' : tf .TensorSpec ([], tf .float64 )},
551
+ non_tensor_params = {'name' : 'OtherAutoCT' },
552
+ non_identifying_kwargs = ('name' ))),
516
553
)
517
554
def testIsCompatibleWith (self , v1 , v2 ):
518
555
self .assertTrue (v1 .is_compatible_with (v2 ))
@@ -625,7 +662,7 @@ def testMostSpecificCompatibleTypeException(self, v1, v2):
625
662
('WithoutCallable' ,
626
663
_TestTypeSpec (
627
664
param_specs = {'a' : tf .TensorSpec ([4 , 2 ], tf .float32 )},
628
- omit_kwargs = ('name' ,))),
665
+ omit_kwargs = ('parameters' ,), non_identifying_kwargs = ( ' name' ,))),
629
666
('WithCallable' ,
630
667
_TestTypeSpec (
631
668
param_specs = {'a' : tf .TensorSpec (None , tf .float32 ),
@@ -636,7 +673,8 @@ def testMostSpecificCompatibleTypeException(self, v1, v2):
636
673
def testRepr (self , spec ):
637
674
spec_data = (auto_composite_tensor ._AUTO_COMPOSITE_TENSOR_VERSION ,
638
675
spec ._param_specs , spec ._non_tensor_params , spec ._omit_kwargs ,
639
- spec ._prefer_static_value , spec ._callable_params )
676
+ spec ._prefer_static_value , spec ._non_identifying_kwargs ,
677
+ spec ._callable_params )
640
678
self .assertEqual (repr (spec ), f'_TestTypeSpec{ spec_data } ' )
641
679
642
680
if __name__ == '__main__' :
0 commit comments