Skip to content

Commit 76b768a

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Call auto_composite_tensor in AutoCompositeTensorBijector's metaclass, eliminating the need for the @auto_composite_tensor decorator on bijector subclasses.
PiperOrigin-RevId: 376966154
1 parent 9efb5fb commit 76b768a

File tree

5 files changed

+76
-28
lines changed

5 files changed

+76
-28
lines changed

tensorflow_probability/python/bijectors/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ multi_substrate_py_library(
124124
],
125125
deps = [
126126
# numpy dep,
127-
# six dep,
128127
# tensorflow dep,
129128
"//tensorflow_probability/python/internal:assert_util",
130129
"//tensorflow_probability/python/internal:auto_composite_tensor",

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
# Dependency imports
2626
import numpy as np
27-
import six
2827
import tensorflow.compat.v2 as tf
2928

3029
from tensorflow_probability.python.internal import assert_util
@@ -91,8 +90,8 @@ def unflatten(info, xs):
9190
tree_util.register_pytree_node(cls, flatten, unflatten)
9291

9392

94-
@six.add_metaclass(_BijectorMeta)
95-
class Bijector(tf.Module):
93+
# TODO(emilyaf): Look at using `__init_subclass__` instead of a metaclass.
94+
class Bijector(tf.Module, metaclass=_BijectorMeta):
9695
r"""Interface for transformations of a `Distribution` sample.
9796
9897
Bijectors can be used to represent any differentiable and injective
@@ -1779,20 +1778,36 @@ def _composite_tensor_shape_params(self):
17791778
return ()
17801779

17811780

1781+
auto_composite_tensor_bijector = functools.partial(
1782+
auto_composite_tensor.auto_composite_tensor,
1783+
omit_kwargs=('parameters',),
1784+
non_identifying_kwargs=('name',),
1785+
module_name='tfp.bijectors')
1786+
1787+
1788+
class _AutoCompositeTensorBijectorMeta(_BijectorMeta):
1789+
"""Metaclass for `AutoCompositeTensorBijector`."""
1790+
1791+
def __new__(mcs, classname, baseclasses, attrs): # pylint: disable=bad-mcs-classmethod-argument
1792+
"""Give subclasses their own type_spec, not an inherited one."""
1793+
1794+
cls = super(_AutoCompositeTensorBijectorMeta, mcs).__new__( # pylint: disable=too-many-function-args
1795+
mcs, classname, baseclasses, attrs)
1796+
return auto_composite_tensor_bijector(cls)
1797+
1798+
17821799
class AutoCompositeTensorBijector(
1783-
Bijector, auto_composite_tensor.AutoCompositeTensor):
1800+
Bijector, auto_composite_tensor.AutoCompositeTensor,
1801+
metaclass=_AutoCompositeTensorBijectorMeta):
17841802
r"""Base for `CompositeTensor` bijectors with auto-generated `TypeSpec`s.
17851803
17861804
`CompositeTensor` objects are able to pass in and out of `tf.function` and
17871805
`tf.while_loop`, or serve as part of the signature of a TF saved model.
17881806
`Bijector` subclasses that follow the contract of
17891807
`tfp.experimental.auto_composite_tensor` may be defined as `CompositeTensor`s
1790-
by inheriting from `AutoCompositeTensorBijector` and applying a class
1791-
decorator as shown here:
1808+
by inheriting from `AutoCompositeTensorBijector`:
17921809
17931810
```python
1794-
@tfp.experimental.auto_composite_tensor(
1795-
omit_kwargs=('name',), module_name='my_module')
17961811
class MyBijector(tfb.AutoCompositeTensorBijector):
17971812
17981813
# The remainder of the subclass implementation is unchanged.
@@ -1801,13 +1816,6 @@ class MyBijector(tfb.AutoCompositeTensorBijector):
18011816
pass
18021817

18031818

1804-
auto_composite_tensor_bijector = functools.partial(
1805-
auto_composite_tensor.auto_composite_tensor,
1806-
omit_kwargs=('parameters',),
1807-
non_identifying_kwargs=('name',),
1808-
module_name='tfp.bijectors')
1809-
1810-
18111819
def check_valid_ndims(ndims, validate=True):
18121820
"""Ensures that `ndims` is a non-negative integer.
18131821

tensorflow_probability/python/bijectors/bijector_test.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import tensorflow.compat.v2 as tf
2828

2929
from tensorflow_probability.python import bijectors as tfb
30-
from tensorflow_probability.python.bijectors import bijector as bijector_lib
3130
from tensorflow_probability.python.internal import cache_util
3231
from tensorflow_probability.python.internal import parameter_properties
3332
from tensorflow_probability.python.internal import prefer_static as ps
@@ -875,11 +874,10 @@ def testNestedCondition(self):
875874
mock_method.assert_called_once_with(mock.ANY, arg1=arg1, arg2=arg2)
876875

877876

878-
@bijector_lib.auto_composite_tensor_bijector
879877
class CompositeForwardBijector(tfb.AutoCompositeTensorBijector):
880878

881-
def __init__(self, scale=2., validate_args=False, name=None):
882-
parameters = dict(locals())
879+
def __init__(self, scale=2., validate_args=False, parameters=None, name=None):
880+
parameters = dict(locals()) if parameters is None else parameters
883881
with tf.name_scope(name or 'forward_only') as name:
884882
self._scale = tensor_util.convert_nonref_to_tensor(
885883
scale,
@@ -897,6 +895,14 @@ def _forward_log_det_jacobian(self, _):
897895
return tf.math.log(self._scale)
898896

899897

898+
class CompositeForwardScaleThree(CompositeForwardBijector):
899+
900+
def __init__(self, name='scale_three'):
901+
parameters = dict(locals())
902+
super(CompositeForwardScaleThree, self).__init__(
903+
scale=3., parameters=parameters, name=name)
904+
905+
900906
@test_util.test_all_tf_execution_regimes
901907
class AutoCompositeTensorBijectorTest(test_util.TestCase):
902908

@@ -917,6 +923,15 @@ def test_disable_ct_bijector(self):
917923
non_ct_bijector.forward(x),
918924
tf.function(lambda b: b.forward(x))(unflat))
919925

926+
def test_composite_tensor_subclass(self):
927+
928+
bij = CompositeForwardScaleThree()
929+
self.assertIs(bij._type_spec.value_type, type(bij))
930+
931+
flat = tf.nest.flatten(bij, expand_composites=True)
932+
unflat = tf.nest.pack_sequence_as(bij, flat, expand_composites=True)
933+
self.assertIsInstance(unflat, CompositeForwardScaleThree)
934+
920935

921936
if __name__ == '__main__':
922937
tf.test.main()

tensorflow_probability/python/internal/auto_composite_tensor.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,32 @@ def _type_spec(self):
395395
pass
396396

397397

398+
def type_spec_register(name, allow_overwrite=True):
399+
"""Decorator used to register a unique name for a TypeSpec subclass.
400+
401+
Unlike TensorFlow's `type_spec.register`, this function allows a new
402+
`TypeSpec` to be registered with a `name` that already appears in the
403+
registry (overwriting the `TypeSpec` already registered with that name). This
404+
allows for re-definition of `AutoCompositeTensor` subclasses in test
405+
environments and iPython.
406+
407+
Args:
408+
name: The name of the type spec. Must have the form
409+
`"{project_name}.{type_name}"`. E.g. `"my_project.MyTypeSpec"`.
410+
allow_overwrite: `bool`, if `True` then the entry in the `TypeSpec` registry
411+
keyed by `name` will be overwritten if it exists. If `False`, then
412+
behavior is the same as `type_spec.register`.
413+
414+
Returns:
415+
A class decorator that registers the decorated class with the given name.
416+
"""
417+
# pylint: disable=protected-access
418+
if allow_overwrite and name in type_spec._NAME_TO_TYPE_SPEC:
419+
type_spec._TYPE_SPEC_TO_NAME.pop(
420+
type_spec._NAME_TO_TYPE_SPEC.pop(name))
421+
return type_spec.register(name)
422+
423+
398424
def auto_composite_tensor(
399425
cls=None, omit_kwargs=(), non_identifying_kwargs=(), module_name=None):
400426
"""Automagically generate `CompositeTensor` behavior for `cls`.
@@ -540,19 +566,13 @@ def body(obj):
540566
type_spec_class_name = f'{cls.__name__}_ACTTypeSpec'
541567
type_spec_name = f'{module_name}.{type_spec_class_name}'
542568

543-
try:
544-
ts = type_spec.lookup(type_spec_name)
545-
return ts.value_type.fget(None)
546-
except ValueError:
547-
pass
548-
549569
# If the declared class is already a CompositeTensor subclass, we can avoid
550570
# affecting the actual type of the returned class. Otherwise, we need to
551571
# explicitly mix in the CT type, and hence create and return a newly
552572
# synthesized type.
553573
if issubclass(cls, composite_tensor.CompositeTensor):
554574

555-
@type_spec.register(type_spec_name)
575+
@type_spec_register(type_spec_name)
556576
class _AlreadyCTTypeSpec(_AutoCompositeTensorTypeSpec):
557577

558578
@property
@@ -576,7 +596,6 @@ def _type_spec(obj):
576596
if clsid in _registry and issubclass(_registry[clsid], cls):
577597
return _registry[clsid]
578598

579-
@type_spec.register(type_spec_name)
580599
class _GeneratedCTTypeSpec(_AutoCompositeTensorTypeSpec):
581600

582601
@property
@@ -595,4 +614,5 @@ def _type_spec(self):
595614

596615
_AutoCompositeTensor.__name__ = cls.__name__
597616
_registry[clsid] = _AutoCompositeTensor
617+
type_spec_register(type_spec_name)(_GeneratedCTTypeSpec)
598618
return _AutoCompositeTensor

tensorflow_probability/python/internal/backend/numpy/type_spec.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,19 @@
1414
# ============================================================================
1515
"""Numpy stub for `type_spec`."""
1616

17+
import re
18+
1719
__all__ = [
1820
'lookup',
1921
'register',
2022
'BatchableTypeSpec',
2123
'TypeSpec',
2224
]
2325

26+
_TYPE_SPEC_TO_NAME = {}
27+
_NAME_TO_TYPE_SPEC = {}
28+
_REGISTERED_NAME_RE = re.compile(r'^(\w+\.)+\w+$')
29+
2430

2531
def register(_):
2632
"""No-op for registering a `tf.TypeSpec` for `saved_model`."""

0 commit comments

Comments
 (0)