Skip to content

Commit 355f190

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Enable AutoCompositeTensor for `LinearOperator' bijectors.
PiperOrigin-RevId: 375531427
1 parent f58370f commit 355f190

File tree

7 files changed

+14
-11
lines changed

7 files changed

+14
-11
lines changed

tensorflow_probability/python/bijectors/affine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ def _as_tensor(x, name, dtype):
213213
return None if x is None else tf.convert_to_tensor(x, name=name, dtype=dtype)
214214

215215

216-
class Affine(bijector.Bijector):
216+
@bijector.auto_composite_tensor_bijector
217+
class Affine(bijector.AutoCompositeTensorBijector):
217218
"""Compute `Y = g(X; shift, scale) = scale @ X + shift`.
218219
219220
Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`.

tensorflow_probability/python/bijectors/affine_linear_operator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
]
3131

3232

33-
class AffineLinearOperator(bijector.Bijector):
33+
@bijector.auto_composite_tensor_bijector
34+
class AffineLinearOperator(bijector.AutoCompositeTensorBijector):
3435
"""Compute `Y = g(X; shift, scale) = scale @ X + shift`.
3536
3637
`shift` is a numeric `Tensor` and `scale` is a `LinearOperator`.

tensorflow_probability/python/bijectors/bijector_properties_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,9 @@
196196
})
197197
COMPOSITE_TENSOR_ATOL = collections.defaultdict(lambda: 1e-6)
198198

199-
# TODO(b/182603117): Enable AutoCT for meta-bijectors and LinearOperator.
199+
# TODO(b/182603117): Enable AutoCT for meta-bijectors.
200200
AUTO_COMPOSITE_TENSOR_IS_BROKEN = [
201201
'FillScaleTriL',
202-
'ScaleMatvecDiag',
203-
'ScaleMatvecTriL',
204202
]
205203

206204

tensorflow_probability/python/bijectors/scale_matvec_diag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import tensorflow.compat.v2 as tf
2222

23+
from tensorflow_probability.python.bijectors import bijector
2324
from tensorflow_probability.python.bijectors import ldj_ratio
2425
from tensorflow_probability.python.bijectors import scale_matvec_linear_operator
2526
from tensorflow_probability.python.internal import dtype_util
@@ -32,8 +33,7 @@
3233
]
3334

3435

35-
# TODO(b/182603117): Enable AutoCompositeTensor once LinearOperators are
36-
# converted to CompositeTensor.
36+
@bijector.auto_composite_tensor_bijector
3737
class ScaleMatvecDiag(scale_matvec_linear_operator.ScaleMatvecLinearOperator):
3838
"""Compute `Y = g(X; scale) = scale @ X`.
3939

tensorflow_probability/python/bijectors/scale_matvec_linear_operator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@
3333
]
3434

3535

36-
# TODO(b/182603117): Enable AutoCompositeTensor once LinearOperators are
37-
# converted to CompositeTensor.
38-
class _ScaleMatvecLinearOperatorBase(bijector.Bijector):
36+
class _ScaleMatvecLinearOperatorBase(bijector.AutoCompositeTensorBijector):
3937
"""Common base class for `ScaleMatvecLinearOperator{Block}`."""
4038

4139
@property
@@ -68,6 +66,7 @@ def _parameter_control_dependencies(self, is_init):
6866
return []
6967

7068

69+
@bijector.auto_composite_tensor_bijector
7170
class ScaleMatvecLinearOperator(_ScaleMatvecLinearOperatorBase):
7271
"""Compute `Y = g(X; scale) = scale @ X`.
7372
@@ -144,6 +143,7 @@ def __init__(self,
144143
name=name)
145144

146145

146+
@bijector.auto_composite_tensor_bijector
147147
class ScaleMatvecLinearOperatorBlock(_ScaleMatvecLinearOperatorBase):
148148
"""Compute `Y = g(X; scale) = scale @ X` for blockwise `X` and `scale`.
149149

tensorflow_probability/python/bijectors/scale_matvec_tril.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import print_function
2020

2121
import tensorflow.compat.v2 as tf
22+
from tensorflow_probability.python.bijectors import bijector
2223
from tensorflow_probability.python.bijectors import fill_triangular as fill_triangular_bijector
2324
from tensorflow_probability.python.bijectors import scale_matvec_linear_operator
2425
from tensorflow_probability.python.internal import dtype_util
@@ -31,6 +32,7 @@
3132
]
3233

3334

35+
@bijector.auto_composite_tensor_bijector
3436
class ScaleMatvecTriL(scale_matvec_linear_operator.ScaleMatvecLinearOperator):
3537
"""Compute `Y = g(X; scale) = scale @ X`.
3638

tensorflow_probability/python/experimental/vi/util/trainable_linear_operators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,8 @@ def build_linear_operator_zeros(
305305
is_self_adjoint=is_square, dtype=dtype)
306306

307307

308-
class _DefaultScaleDiagonal(bijector_lib.Bijector):
308+
@bijector_lib.auto_composite_tensor_bijector
309+
class _DefaultScaleDiagonal(bijector_lib.AutoCompositeTensorBijector):
309310
"""Default bijector for constraining the diagonal of scale matrices."""
310311

311312
def __init__(self):

0 commit comments

Comments
 (0)