Skip to content

Commit 5ec3bf1

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Remove tfp.distributions dependency from tfp.experimental.composite_tensor and import it only when needed (reduces integration test footprint).
PiperOrigin-RevId: 476165470
1 parent 0bd5811 commit 5ec3bf1

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

tensorflow_probability/python/experimental/util/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ py_library(
156156
srcs = ["composite_tensor.py"],
157157
deps = [
158158
# tensorflow dep,
159-
"//tensorflow_probability/python/distributions",
160159
],
161160
)
162161

tensorflow_probability/python/experimental/util/composite_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import six
1919

2020
import tensorflow.compat.v2 as tf
21-
from tensorflow_probability.python import distributions
2221
from tensorflow_probability.python.internal import tensor_util
2322
from tensorflow.python.framework.composite_tensor import CompositeTensor # pylint: disable=g-direct-tensorflow-import
2423
from tensorflow.python.saved_model import nested_structure_coder # pylint: disable=g-direct-tensorflow-import
@@ -160,6 +159,7 @@ def _find_clsid(clsid):
160159
pkg, cls = clsid
161160
if clsid not in _registry:
162161
if pkg.startswith('tensorflow_probability.') and '.distributions' in pkg:
162+
from tensorflow_probability.python import distributions # pylint: disable=g-import-not-at-top
163163
dist_cls = getattr(distributions, cls)
164164
if (inspect.isclass(dist_cls) and
165165
issubclass(dist_cls, distributions.Distribution)):

tensorflow_probability/python/experimental/vi/util/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ multi_substrate_py_library(
4848
"//tensorflow_probability/python/internal:dtype_util",
4949
"//tensorflow_probability/python/internal:prefer_static",
5050
"//tensorflow_probability/python/internal:samplers",
51+
"//tensorflow_probability/python/internal:trainable_state_util",
5152
],
5253
)
5354

tensorflow_probability/python/experimental/vi/util/trainable_linear_operators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def _trainable_linear_operator_tril(
218218
scale_tril_bijector = fill_scale_tril.FillScaleTriL(
219219
diag_bijector, diag_shift=tf.zeros([], dtype=dtype))
220220
scale_tril = yield trainable_state_util.Parameter(
221-
init_fn=lambda seed: scale_tril_bijector( # pylint: disable=g-long-lambda
221+
init_fn=lambda seed: scale_tril_bijector.forward( # pylint: disable=g-long-lambda
222222
samplers.normal(
223223
mean=0.,
224224
stddev=scale_initializer,
@@ -262,7 +262,7 @@ def _trainable_linear_operator_diag(
262262

263263
diag_bijector = diag_bijector or _DefaultScaleDiagonal()
264264
scale_diag = yield trainable_state_util.Parameter(
265-
init_fn=lambda seed: diag_bijector( # pylint: disable=g-long-lambda
265+
init_fn=lambda seed: diag_bijector.forward( # pylint: disable=g-long-lambda
266266
samplers.normal(
267267
mean=0.,
268268
stddev=scale_initializer,

0 commit comments

Comments
 (0)