Skip to content

Commit e838d8b

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Add non_identifying_kwargs to AutoCompositeTensor. These kwargs are preserved through flatten/unflatten and appear in serializations, but are omitted from comparison/equality checks and hashes.
PiperOrigin-RevId: 374718734
1 parent 5b61370 commit e838d8b

File tree

3 files changed

+98
-31
lines changed

3 files changed

+98
-31
lines changed

tensorflow_probability/python/internal/auto_composite_tensor.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def _deferred_assertion_context(is_deferred=True):
6060

6161
_SENTINEL = object()
6262

63-
_AUTO_COMPOSITE_TENSOR_VERSION = 2
63+
_AUTO_COMPOSITE_TENSOR_VERSION = 3
6464

6565
# Cache maps __init__ method to signature
6666
_sig_cache = {}
@@ -171,7 +171,8 @@ class _AutoCompositeTensorTypeSpec(tf.TypeSpec):
171171
'_comparable')
172172

173173
def __init__(self, param_specs, non_tensor_params, omit_kwargs,
174-
prefer_static_value, callable_params=None):
174+
prefer_static_value, non_identifying_kwargs,
175+
callable_params=None):
175176
"""Initializes a new `_AutoCompositeTensorTypeSpec`.
176177
177178
Args:
@@ -189,6 +190,11 @@ def __init__(self, param_specs, non_tensor_params, omit_kwargs,
189190
of `Tensor`-like kwargs to the `AutoCompositeTensor`s constructor that
190191
may be stored as static values, if known. These are typically shapes or
191192
axis values.
193+
non_identifying_kwargs: Python `tuple` of strings corresponding to the
194+
names of kwargs to the `AutoCompositeTensor`s constructor whose values
195+
are not relevant to the unique identification of the
196+
`_AutoCompositeTensorTypeSpec` instance. Equality/comparison checks and
197+
`__hash__` do not depend on these kwargs.
192198
callable_params: Python `dict` of callable kwargs to the
193199
`AutoCompositeTensor`'s constructor that do not subclass
194200
`CompositeTensor`, or `None`. If `callable_params` is a non-empty
@@ -199,22 +205,32 @@ def __init__(self, param_specs, non_tensor_params, omit_kwargs,
199205
self._non_tensor_params = non_tensor_params
200206
self._omit_kwargs = omit_kwargs
201207
self._prefer_static_value = prefer_static_value
208+
self._non_identifying_kwargs = non_identifying_kwargs
202209
self._callable_params = {} if callable_params is None else callable_params
203210

204211
self._serializable = (
205212
_AUTO_COMPOSITE_TENSOR_VERSION,
206213
self._param_specs,
207214
self._non_tensor_params,
208215
self._omit_kwargs,
209-
self._prefer_static_value)
216+
self._prefer_static_value,
217+
self._non_identifying_kwargs)
210218

211-
# TODO(b/182603117): Distinguish between `omit_kwargs_from_constructor`
212-
# and `omit_kwargs_for_comparison`.
213-
self._comparable = self._serializable + (
214-
tf.nest.map_structure(id, self._callable_params),)
219+
def remove_kwargs(d):
220+
return {k: v for k, v in d.items()
221+
if k not in self._non_identifying_kwargs}
222+
223+
self._comparable = (
224+
_AUTO_COMPOSITE_TENSOR_VERSION,
225+
remove_kwargs(self._param_specs),
226+
remove_kwargs(self._non_tensor_params),
227+
self._omit_kwargs,
228+
self._prefer_static_value,
229+
self._non_identifying_kwargs,
230+
tf.nest.map_structure(id, remove_kwargs(self._callable_params)))
215231

216232
@classmethod
217-
def from_instance(cls, instance, omit_kwargs=()):
233+
def from_instance(cls, instance, omit_kwargs=(), non_identifying_kwargs=()):
218234
cls_value_type = cls.value_type.fget(None)
219235
if type(instance) is not cls_value_type: # pylint: disable=unidiomatic-typecheck
220236
raise ValueError(f'`{type(instance).__name__}` has inherited the '
@@ -245,6 +261,7 @@ def from_instance(cls, instance, omit_kwargs=()):
245261
non_tensor_params=non_tensor_params,
246262
omit_kwargs=omit_kwargs,
247263
prefer_static_value=prefer_static_value,
264+
non_identifying_kwargs=non_identifying_kwargs,
248265
callable_params=callable_params)
249266

250267
def _to_components(self, obj):
@@ -273,6 +290,9 @@ def _deserialize(cls, encoded):
273290
if version == 1:
274291
encoded = encoded + ((),)
275292
version = 2
293+
if version == 2:
294+
encoded = encoded + ((),)
295+
version = 3
276296
if version != _AUTO_COMPOSITE_TENSOR_VERSION:
277297
raise ValueError(f'Expected version {_AUTO_COMPOSITE_TENSOR_VERSION},'
278298
f' but got {version}.')
@@ -375,7 +395,8 @@ def _type_spec(self):
375395
pass
376396

377397

378-
def auto_composite_tensor(cls=None, omit_kwargs=(), module_name=None):
398+
def auto_composite_tensor(
399+
cls=None, omit_kwargs=(), non_identifying_kwargs=(), module_name=None):
379400
"""Automagically generate `CompositeTensor` behavior for `cls`.
380401
381402
`CompositeTensor` objects are able to pass in and out of `tf.function` and
@@ -499,6 +520,8 @@ def body(obj):
499520
Args:
500521
cls: The class for which to create a CompositeTensor subclass.
501522
omit_kwargs: Optional sequence of kwarg names to be omitted from the spec.
523+
non_identifying_kwargs: Optional sequence of kwarg names to be omitted from
524+
equality/comparison checks and the `__hash__` method of the spec.
502525
module_name: The module name with which to register the `TypeSpec`. If
503526
`None`, defaults to `cls.__module__`.
504527
@@ -508,6 +531,7 @@ def body(obj):
508531
if cls is None:
509532
return functools.partial(auto_composite_tensor,
510533
omit_kwargs=omit_kwargs,
534+
non_identifying_kwargs=non_identifying_kwargs,
511535
module_name=module_name)
512536

513537
if module_name is None:
@@ -537,11 +561,15 @@ def value_type(self):
537561

538562
_AlreadyCTTypeSpec.__name__ = type_spec_class_name
539563

540-
cls._type_spec = property( # pylint: disable=protected-access
541-
lambda self: _AlreadyCTTypeSpec.from_instance(self, omit_kwargs))
564+
def _type_spec(obj):
565+
return _AlreadyCTTypeSpec.from_instance(
566+
obj, omit_kwargs, non_identifying_kwargs)
567+
568+
cls._type_spec = property(_type_spec) # pylint: disable=protected-access
542569
return cls
543570

544-
clsid = (cls.__module__, cls.__name__, omit_kwargs)
571+
clsid = (cls.__module__, cls.__name__, omit_kwargs,
572+
non_identifying_kwargs)
545573

546574
# Check for subclass if retrieving from the _registry, in case the user
547575
# has redefined the class (e.g. in a REPL/notebook).
@@ -562,7 +590,8 @@ class _AutoCompositeTensor(cls, composite_tensor.CompositeTensor):
562590

563591
@property
564592
def _type_spec(self):
565-
return _GeneratedCTTypeSpec.from_instance(self, omit_kwargs)
593+
return _GeneratedCTTypeSpec.from_instance(
594+
self, omit_kwargs, non_identifying_kwargs)
566595

567596
_AutoCompositeTensor.__name__ = cls.__name__
568597
_registry[clsid] = _AutoCompositeTensor

tensorflow_probability/python/internal/auto_composite_tensor_test.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,20 @@
4949

5050

5151
AutoIdentity = tfp.experimental.auto_composite_tensor(
52-
tf.linalg.LinearOperatorIdentity, omit_kwargs=('name',))
52+
tf.linalg.LinearOperatorIdentity, non_identifying_kwargs=('name',))
5353
AutoDiag = tfp.experimental.auto_composite_tensor(
54-
tf.linalg.LinearOperatorDiag, omit_kwargs=('name',))
54+
tf.linalg.LinearOperatorDiag, non_identifying_kwargs=('name',))
5555
AutoBlockDiag = tfp.experimental.auto_composite_tensor(
56-
tf.linalg.LinearOperatorBlockDiag, omit_kwargs=('name',))
56+
tf.linalg.LinearOperatorBlockDiag, non_identifying_kwargs=('name',))
5757
AutoTriL = tfp.experimental.auto_composite_tensor(
58-
tf.linalg.LinearOperatorLowerTriangular, omit_kwargs=('name',))
58+
tf.linalg.LinearOperatorLowerTriangular, non_identifying_kwargs=('name',))
5959

6060
AutoNormal = tfp.experimental.auto_composite_tensor(
61-
tfd.Normal, omit_kwargs=('name',))
61+
tfd.Normal, non_identifying_kwargs=('name',))
6262
AutoIndependent = tfp.experimental.auto_composite_tensor(
63-
tfd.Independent, omit_kwargs=('name',))
63+
tfd.Independent, non_identifying_kwargs=('name',))
6464
AutoReshape = tfp.experimental.auto_composite_tensor(
65-
tfb.Reshape, omit_kwargs=('name',))
65+
tfb.Reshape, non_identifying_kwargs=('name',))
6666

6767

6868
class Model(tf.Module):
@@ -105,7 +105,7 @@ def tearDownModule():
105105
class AutoCompositeTensorTest(test_util.TestCase):
106106

107107
def test_example(self):
108-
@tfp.experimental.auto_composite_tensor(omit_kwargs=('name',))
108+
@tfp.experimental.auto_composite_tensor(non_identifying_kwargs=('name',))
109109
class Adder(object):
110110

111111
def __init__(self, x, y, name=None):
@@ -185,7 +185,7 @@ def test_preconditioner(self):
185185
tfed = tfp.experimental.distributions
186186
auto_ct_mvn_prec_linop = tfp.experimental.auto_composite_tensor(
187187
tfed.MultivariateNormalPrecisionFactorLinearOperator,
188-
omit_kwargs=('name',))
188+
non_identifying_kwargs=('name',))
189189
tril = AutoTriL(**cov_linop.cholesky().parameters)
190190
momentum_distribution = auto_ct_mvn_prec_linop(precision_factor=tril)
191191
def body(d):
@@ -408,15 +408,26 @@ def __init__(self):
408408
d_ct = AutoStandardNormal()
409409
self.assertLen(tf.nest.flatten(d_ct, expand_composites=True), 0)
410410

411+
def test_names_preserved_through_flatten(self):
412+
413+
dist = AutoNormal(0., scale=3., name='ScaleThreeNormal')
414+
flat = tf.nest.flatten(dist, expand_composites=True)
415+
unflat = tf.nest.pack_sequence_as(dist, flat, expand_composites=True)
416+
unflat_name = ('ScaleThreeNormal' if tf.executing_eagerly()
417+
else 'ScaleThreeNormal_1')
418+
self.assertEqual(unflat.name, unflat_name)
419+
411420

412421
class _TestTypeSpec(auto_composite_tensor._AutoCompositeTensorTypeSpec):
413422

414423
def __init__(self, param_specs, non_tensor_params=None, omit_kwargs=(),
415-
prefer_static_value=(), callable_params=None):
424+
prefer_static_value=(), non_identifying_kwargs=(),
425+
callable_params=None):
416426
non_tensor_params = {} if non_tensor_params is None else non_tensor_params
417427
super(_TestTypeSpec, self).__init__(
418428
param_specs, non_tensor_params=non_tensor_params,
419429
omit_kwargs=omit_kwargs, prefer_static_value=prefer_static_value,
430+
non_identifying_kwargs=non_identifying_kwargs,
420431
callable_params=callable_params)
421432

422433
@property
@@ -452,7 +463,16 @@ class AutoCompositeTensorTypeSpecTest(test_util.TestCase):
452463
'b': tfb.Scale(3.)._type_spec},
453464
omit_kwargs=('name', 'foo'),
454465
prefer_static_value=('a',),
455-
callable_params={'f': tf.math.exp}))
466+
callable_params={'f': tf.math.exp})),
467+
('DifferentNonIdentifyingKwargsValues',
468+
_TestTypeSpec(
469+
param_specs={'x': tf.TensorSpec([], tf.float64)},
470+
non_tensor_params={'name': 'MyAutoCT'},
471+
non_identifying_kwargs=('name')),
472+
_TestTypeSpec(
473+
param_specs={'x': tf.TensorSpec([], tf.float64)},
474+
non_tensor_params={'name': 'OtherAutoCT'},
475+
non_identifying_kwargs=('name'))),
456476
)
457477
def testEquality(self, v1, v2):
458478
# pylint: disable=g-generic-assert
@@ -480,7 +500,15 @@ def testEquality(self, v1, v2):
480500
_TestTypeSpec(
481501
param_specs={'a': tf.TensorSpec([3, None], tf.float32)},
482502
omit_kwargs=('name', 'foo'),
483-
callable_params={'f': tf.math.sigmoid}))
503+
callable_params={'f': tf.math.sigmoid})),
504+
('DifferentMetadata',
505+
_TestTypeSpec(
506+
param_specs={'a': tf.TensorSpec([3, 2], tf.float32)},
507+
non_tensor_params={'validate_args': True},
508+
non_identifying_kwargs=('name',)),
509+
_TestTypeSpec(
510+
param_specs={'a': tf.TensorSpec([3, None], tf.float32)},
511+
non_tensor_params={'validate_args': True})),
484512
)
485513
def testInequality(self, v1, v2):
486514
# pylint: disable=g-generic-assert
@@ -512,7 +540,16 @@ def testInequality(self, v1, v2):
512540
param_specs={'a': tf.TensorSpec([3, None], tf.float32),
513541
'b': tfb.Scale(3.)._type_spec},
514542
omit_kwargs=('name', 'foo'),
515-
callable_params={'f': tf.math.exp}))
543+
callable_params={'f': tf.math.exp})),
544+
('DifferentNonIdentifyingKwargsValues',
545+
_TestTypeSpec(
546+
param_specs={'x': tf.TensorSpec(None, tf.float64)},
547+
non_tensor_params={'name': 'MyAutoCT'},
548+
non_identifying_kwargs=('name')),
549+
_TestTypeSpec(
550+
param_specs={'x': tf.TensorSpec([], tf.float64)},
551+
non_tensor_params={'name': 'OtherAutoCT'},
552+
non_identifying_kwargs=('name'))),
516553
)
517554
def testIsCompatibleWith(self, v1, v2):
518555
self.assertTrue(v1.is_compatible_with(v2))
@@ -625,7 +662,7 @@ def testMostSpecificCompatibleTypeException(self, v1, v2):
625662
('WithoutCallable',
626663
_TestTypeSpec(
627664
param_specs={'a': tf.TensorSpec([4, 2], tf.float32)},
628-
omit_kwargs=('name',))),
665+
omit_kwargs=('parameters',), non_identifying_kwargs=('name',))),
629666
('WithCallable',
630667
_TestTypeSpec(
631668
param_specs={'a': tf.TensorSpec(None, tf.float32),
@@ -636,7 +673,8 @@ def testMostSpecificCompatibleTypeException(self, v1, v2):
636673
def testRepr(self, spec):
637674
spec_data = (auto_composite_tensor._AUTO_COMPOSITE_TENSOR_VERSION,
638675
spec._param_specs, spec._non_tensor_params, spec._omit_kwargs,
639-
spec._prefer_static_value, spec._callable_params)
676+
spec._prefer_static_value, spec._non_identifying_kwargs,
677+
spec._callable_params)
640678
self.assertEqual(repr(spec), f'_TestTypeSpec{spec_data}')
641679

642680
if __name__ == '__main__':

tensorflow_probability/python/util/deferred_tensor_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -621,9 +621,9 @@ def testRepr(self):
621621
expected = (
622622
"_TransformedVariableSpec(input_spec=TensorSpec(shape=(4, 2), "
623623
"dtype=tf.float32, name=None), "
624-
"transform_or_spec=Sigmoid_ACTTypeSpec(2, {}, "
625-
"{'low': None, 'high': None, 'validate_args': True}, ('name',), (), "
626-
"{}), dtype=<dtype: 'float64'>, name=None)")
624+
"transform_or_spec=Sigmoid_ACTTypeSpec(3, {}, {'low': None, 'high': "
625+
"None, 'validate_args': True}, ('name',), (), (), {}), dtype=<dtype: "
626+
"'float64'>, name=None)")
627627
self.assertEqual(repr(spec), expected)
628628

629629

0 commit comments

Comments
 (0)