Skip to content

Commit c716e6f

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Remove auto_composite_tensor decorators on bijector subclasses now that it's handled in a metaclass.
PiperOrigin-RevId: 377152250
1 parent 8352cdd commit c716e6f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+13
-100
lines changed

tensorflow_probability/python/bijectors/BUILD

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -264,27 +264,26 @@ multi_substrate_py_library(
264264
name = "scale_matvec_diag",
265265
srcs = ["scale_matvec_diag.py"],
266266
deps = [
267-
":bijector",
268267
":ldj_ratio",
268+
":scale_matvec_linear_operator",
269269
# tensorflow dep,
270-
"//tensorflow_probability/python/internal:assert_util",
271270
"//tensorflow_probability/python/internal:dtype_util",
272271
"//tensorflow_probability/python/internal:parameter_properties",
272+
"//tensorflow_probability/python/internal:tensor_util",
273273
],
274274
)
275275

276276
multi_substrate_py_library(
277277
name = "scale_matvec_tril",
278278
srcs = ["scale_matvec_tril.py"],
279279
deps = [
280-
":bijector",
281280
":fill_triangular",
282-
":identity",
281+
":scale_matvec_linear_operator",
283282
# tensorflow dep,
284-
"//tensorflow_probability/python/internal:assert_util",
285283
"//tensorflow_probability/python/internal:dtype_util",
286284
"//tensorflow_probability/python/internal:parameter_properties",
287285
"//tensorflow_probability/python/internal:prefer_static",
286+
"//tensorflow_probability/python/internal:tensor_util",
288287
],
289288
)
290289

@@ -301,11 +300,9 @@ multi_substrate_py_library(
301300
name = "blockwise",
302301
srcs = ["blockwise.py"],
303302
deps = [
304-
":bijector",
305303
# numpy dep,
306304
# tensorflow dep,
307305
"//tensorflow_probability/python/internal:assert_util",
308-
"//tensorflow_probability/python/internal:auto_composite_tensor",
309306
"//tensorflow_probability/python/internal:prefer_static",
310307
"//tensorflow_probability/python/internal:tensorshape_util",
311308
],
@@ -414,12 +411,8 @@ multi_substrate_py_library(
414411
name = "exp",
415412
srcs = ["exp.py"],
416413
deps = [
417-
":bijector",
418414
":invert",
419415
":power_transform",
420-
"//tensorflow_probability/python/internal:assert_util",
421-
"//tensorflow_probability/python/internal:auto_composite_tensor",
422-
"//tensorflow_probability/python/internal:prefer_static",
423416
],
424417
)
425418

@@ -430,7 +423,6 @@ multi_substrate_py_library(
430423
":bijector",
431424
":invert",
432425
# tensorflow dep,
433-
"//tensorflow_probability/python/internal:auto_composite_tensor",
434426
],
435427
)
436428

@@ -578,7 +570,6 @@ multi_substrate_py_library(
578570
srcs = ["invert.py"],
579571
deps = [
580572
":bijector",
581-
"//tensorflow_probability/python/internal:auto_composite_tensor",
582573
],
583574
)
584575

tensorflow_probability/python/bijectors/absolute_value.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
]
3030

3131

32-
@bijector.auto_composite_tensor_bijector
3332
class AbsoluteValue(bijector.AutoCompositeTensorBijector):
3433
"""Computes `Y = g(X) = Abs(X)`, element-wise.
3534

tensorflow_probability/python/bijectors/affine.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,6 @@ def _as_tensor(x, name, dtype):
213213
return None if x is None else tf.convert_to_tensor(x, name=name, dtype=dtype)
214214

215215

216-
@bijector.auto_composite_tensor_bijector
217216
class Affine(bijector.AutoCompositeTensorBijector):
218217
"""Compute `Y = g(X; shift, scale) = scale @ X + shift`.
219218

tensorflow_probability/python/bijectors/affine_linear_operator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
]
3131

3232

33-
@bijector.auto_composite_tensor_bijector
3433
class AffineLinearOperator(bijector.AutoCompositeTensorBijector):
3534
"""Compute `Y = g(X; shift, scale) = scale @ X + shift`.
3635

tensorflow_probability/python/bijectors/ascending.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
]
3030

3131

32-
@bijector.auto_composite_tensor_bijector
3332
class Ascending(bijector.AutoCompositeTensorBijector):
3433
"""Maps unconstrained R^n to R^n in ascending order.
3534

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import abc
2222
import contextlib
23-
import functools
2423

2524
# Dependency imports
2625
import numpy as np
@@ -1778,13 +1777,6 @@ def _composite_tensor_shape_params(self):
17781777
return ()
17791778

17801779

1781-
auto_composite_tensor_bijector = functools.partial(
1782-
auto_composite_tensor.auto_composite_tensor,
1783-
omit_kwargs=('parameters',),
1784-
non_identifying_kwargs=('name',),
1785-
module_name='tfp.bijectors')
1786-
1787-
17881780
class _AutoCompositeTensorBijectorMeta(_BijectorMeta):
17891781
"""Metaclass for `AutoCompositeTensorBijector`."""
17901782

@@ -1793,7 +1785,11 @@ def __new__(mcs, classname, baseclasses, attrs): # pylint: disable=bad-mcs-clas
17931785

17941786
cls = super(_AutoCompositeTensorBijectorMeta, mcs).__new__( # pylint: disable=too-many-function-args
17951787
mcs, classname, baseclasses, attrs)
1796-
return auto_composite_tensor_bijector(cls)
1788+
return auto_composite_tensor.auto_composite_tensor(
1789+
cls,
1790+
omit_kwargs=('parameters',),
1791+
non_identifying_kwargs=('name',),
1792+
module_name='tfp.bijectors')
17971793

17981794

17991795
class AutoCompositeTensorBijector(

tensorflow_probability/python/bijectors/blockwise.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from tensorflow_probability.python.bijectors import joint_map
2929
from tensorflow_probability.python.bijectors import split
3030
from tensorflow_probability.python.internal import assert_util
31-
from tensorflow_probability.python.internal import auto_composite_tensor
3231
from tensorflow_probability.python.internal import prefer_static as ps
3332
from tensorflow_probability.python.internal import tensorshape_util
3433

@@ -270,8 +269,7 @@ def _validate_block_sizes(block_sizes, bijectors, validate_args):
270269
return block_sizes
271270

272271

273-
@bijector_lib.auto_composite_tensor_bijector
274-
class Blockwise(_Blockwise, auto_composite_tensor.AutoCompositeTensor):
272+
class Blockwise(_Blockwise, bijector_lib.AutoCompositeTensorBijector):
275273

276274
def __new__(cls, *args, **kwargs):
277275
"""Returns a `_Blockwise` if any of `bijectors` is not `CompositeTensor."""

tensorflow_probability/python/bijectors/categorical_to_discrete.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
]
3737

3838

39-
@bijector.auto_composite_tensor_bijector
4039
class CategoricalToDiscrete(bijector.AutoCompositeTensorBijector):
4140
"""Bijector which computes `Y = g(X) = values[X]`.
4241

tensorflow_probability/python/bijectors/cholesky_outer_product.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
]
3636

3737

38-
@bijector.auto_composite_tensor_bijector
3938
class CholeskyOuterProduct(bijector.AutoCompositeTensorBijector):
4039
"""Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix.
4140

tensorflow_probability/python/bijectors/cholesky_to_inv_cholesky.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
]
3333

3434

35-
@bijector.auto_composite_tensor_bijector
3635
class CholeskyToInvCholesky(bijector.AutoCompositeTensorBijector):
3736
"""Maps the Cholesky factor of `M` to the Cholesky factor of `M^{-1}`.
3837

0 commit comments

Comments
 (0)