Skip to content

Commit 26aeb77

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Add an auto_composite_tensor_bijector decorator for bijectors that preserves the name attribute through flattening/unflattening and in serialization.
PiperOrigin-RevId: 374731367
1 parent ef2d464 commit 26aeb77

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+68
-207
lines changed

tensorflow_probability/python/bijectors/BUILD

Lines changed: 0 additions & 48 deletions
Large diffs are not rendered by default.

tensorflow_probability/python/bijectors/absolute_value.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,14 @@
2222

2323
from tensorflow_probability.python.bijectors import bijector
2424
from tensorflow_probability.python.internal import assert_util
25-
from tensorflow_probability.python.internal import auto_composite_tensor
2625
from tensorflow_probability.python.internal import dtype_util
2726

2827
__all__ = [
2928
'AbsoluteValue',
3029
]
3130

3231

33-
@auto_composite_tensor.auto_composite_tensor(
34-
omit_kwargs=('name',), module_name='tfp.bijectors')
32+
@bijector.auto_composite_tensor_bijector
3533
class AbsoluteValue(bijector.AutoCompositeTensorBijector):
3634
"""Computes `Y = g(X) = Abs(X)`, element-wise.
3735

tensorflow_probability/python/bijectors/ascending.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,14 @@
2222

2323
from tensorflow_probability.python.bijectors import bijector
2424
from tensorflow_probability.python.internal import assert_util
25-
from tensorflow_probability.python.internal import auto_composite_tensor
2625

2726

2827
__all__ = [
2928
'Ascending',
3029
]
3130

3231

33-
@auto_composite_tensor.auto_composite_tensor(
34-
omit_kwargs=('name',), module_name='tfp.bijectors')
32+
@bijector.auto_composite_tensor_bijector
3533
class Ascending(bijector.AutoCompositeTensorBijector):
3634
"""Maps unconstrained R^n to R^n in ascending order.
3735

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import abc
2222
import contextlib
23+
import functools
2324

2425
# Dependency imports
2526
import numpy as np
@@ -1620,6 +1621,13 @@ class MyBijector(tfb.AutoCompositeTensorBijector):
16201621
pass
16211622

16221623

1624+
auto_composite_tensor_bijector = functools.partial(
1625+
auto_composite_tensor.auto_composite_tensor,
1626+
omit_kwargs=('parameters',),
1627+
non_identifying_kwargs=('name',),
1628+
module_name='tfp.bijectors')
1629+
1630+
16231631
def check_valid_ndims(ndims, validate=True):
16241632
"""Ensures that `ndims` is a non-negative integer.
16251633

tensorflow_probability/python/bijectors/bijector_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import tensorflow.compat.v1 as tf1
2626
import tensorflow.compat.v2 as tf
2727
from tensorflow_probability.python import bijectors as tfb
28-
from tensorflow_probability.python.internal import auto_composite_tensor
28+
from tensorflow_probability.python.bijectors import bijector as bijector_lib
2929
from tensorflow_probability.python.internal import cache_util
3030
from tensorflow_probability.python.internal import tensor_util
3131
from tensorflow_probability.python.internal import test_util
@@ -768,7 +768,7 @@ def testNestedCondition(self):
768768
mock_method.assert_called_once_with(mock.ANY, arg1=arg1, arg2=arg2)
769769

770770

771-
@auto_composite_tensor.auto_composite_tensor(omit_kwargs=('name',))
771+
@bijector_lib.auto_composite_tensor_bijector
772772
class CompositeForwardBijector(tfb.AutoCompositeTensorBijector):
773773

774774
def __init__(self, scale=2., validate_args=False, name=None):

tensorflow_probability/python/bijectors/categorical_to_discrete.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
]
3737

3838

39-
class CategoricalToDiscrete(bijector.Bijector):
39+
@bijector.auto_composite_tensor_bijector
40+
class CategoricalToDiscrete(bijector.AutoCompositeTensorBijector):
4041
"""Bijector which computes `Y = g(X) = values[X]`.
4142
4243
Example Usage:

tensorflow_probability/python/bijectors/cholesky_outer_product.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
from tensorflow_probability.python.bijectors import bijector
2626
from tensorflow_probability.python.internal import assert_util
27-
from tensorflow_probability.python.internal import auto_composite_tensor
2827
from tensorflow_probability.python.internal import distribution_util
2928
from tensorflow_probability.python.internal import dtype_util
3029
from tensorflow_probability.python.internal import prefer_static as ps
@@ -36,8 +35,7 @@
3635
]
3736

3837

39-
@auto_composite_tensor.auto_composite_tensor(
40-
omit_kwargs=('name',), module_name='tfp.bijectors')
38+
@bijector.auto_composite_tensor_bijector
4139
class CholeskyOuterProduct(bijector.AutoCompositeTensorBijector):
4240
"""Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix.
4341

tensorflow_probability/python/bijectors/cholesky_to_inv_cholesky.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from tensorflow_probability.python.bijectors import bijector
2424
from tensorflow_probability.python.bijectors.cholesky_outer_product import CholeskyOuterProduct
2525
from tensorflow_probability.python.internal import assert_util
26-
from tensorflow_probability.python.internal import auto_composite_tensor
2726
from tensorflow_probability.python.internal import dtype_util
2827
from tensorflow_probability.python.internal import prefer_static as ps
2928

@@ -33,8 +32,7 @@
3332
]
3433

3534

36-
@auto_composite_tensor.auto_composite_tensor(
37-
omit_kwargs=('name',), module_name='tfp.bijectors')
35+
@bijector.auto_composite_tensor_bijector
3836
class CholeskyToInvCholesky(bijector.AutoCompositeTensorBijector):
3937
"""Maps the Cholesky factor of `M` to the Cholesky factor of `M^{-1}`.
4038

tensorflow_probability/python/bijectors/correlation_cholesky.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
from tensorflow_probability.python.bijectors import bijector
2626
from tensorflow_probability.python.bijectors import fill_triangular
27-
from tensorflow_probability.python.internal import auto_composite_tensor
2827
from tensorflow_probability.python.internal import prefer_static as ps
2928
from tensorflow_probability.python.internal import tensorshape_util
3029

@@ -33,8 +32,7 @@
3332
]
3433

3534

36-
@auto_composite_tensor.auto_composite_tensor(
37-
omit_kwargs=('name',), module_name='tfp.bijectors')
35+
@bijector.auto_composite_tensor_bijector
3836
class CorrelationCholesky(bijector.AutoCompositeTensorBijector):
3937
"""Maps unconstrained reals to Cholesky-space correlation matrices.
4038

tensorflow_probability/python/bijectors/cumsum.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,14 @@
2020

2121
import tensorflow.compat.v2 as tf
2222
from tensorflow_probability.python.bijectors import bijector
23-
from tensorflow_probability.python.internal import auto_composite_tensor
2423
from tensorflow_probability.python.internal import prefer_static
2524

2625
__all__ = [
2726
'Cumsum',
2827
]
2928

3029

31-
@auto_composite_tensor.auto_composite_tensor(
32-
omit_kwargs=('name',), module_name='tfp.bijectors')
30+
@bijector.auto_composite_tensor_bijector
3331
class Cumsum(bijector.AutoCompositeTensorBijector):
3432
"""Computes the cumulative sum of a tensor along a specified axis.
3533

0 commit comments

Comments
 (0)