File tree Expand file tree Collapse file tree 4 files changed +4
-4
lines changed
tensorflow_probability/python/experimental Expand file tree Collapse file tree 4 files changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -156,7 +156,6 @@ py_library(
156
156
srcs = ["composite_tensor.py" ],
157
157
deps = [
158
158
# tensorflow dep,
159
- "//tensorflow_probability/python/distributions" ,
160
159
],
161
160
)
162
161
Original file line number Diff line number Diff line change 18
18
import six
19
19
20
20
import tensorflow .compat .v2 as tf
21
- from tensorflow_probability .python import distributions
22
21
from tensorflow_probability .python .internal import tensor_util
23
22
from tensorflow .python .framework .composite_tensor import CompositeTensor # pylint: disable=g-direct-tensorflow-import
24
23
from tensorflow .python .saved_model import nested_structure_coder # pylint: disable=g-direct-tensorflow-import
@@ -160,6 +159,7 @@ def _find_clsid(clsid):
160
159
pkg , cls = clsid
161
160
if clsid not in _registry :
162
161
if pkg .startswith ('tensorflow_probability.' ) and '.distributions' in pkg :
162
+ from tensorflow_probability .python import distributions # pylint: disable=g-import-not-at-top
163
163
dist_cls = getattr (distributions , cls )
164
164
if (inspect .isclass (dist_cls ) and
165
165
issubclass (dist_cls , distributions .Distribution )):
Original file line number Diff line number Diff line change @@ -48,6 +48,7 @@ multi_substrate_py_library(
48
48
"//tensorflow_probability/python/internal:dtype_util" ,
49
49
"//tensorflow_probability/python/internal:prefer_static" ,
50
50
"//tensorflow_probability/python/internal:samplers" ,
51
+ "//tensorflow_probability/python/internal:trainable_state_util" ,
51
52
],
52
53
)
53
54
Original file line number Diff line number Diff line change @@ -218,7 +218,7 @@ def _trainable_linear_operator_tril(
218
218
scale_tril_bijector = fill_scale_tril .FillScaleTriL (
219
219
diag_bijector , diag_shift = tf .zeros ([], dtype = dtype ))
220
220
scale_tril = yield trainable_state_util .Parameter (
221
- init_fn = lambda seed : scale_tril_bijector ( # pylint: disable=g-long-lambda
221
+ init_fn = lambda seed : scale_tril_bijector . forward ( # pylint: disable=g-long-lambda
222
222
samplers .normal (
223
223
mean = 0. ,
224
224
stddev = scale_initializer ,
@@ -262,7 +262,7 @@ def _trainable_linear_operator_diag(
262
262
263
263
diag_bijector = diag_bijector or _DefaultScaleDiagonal ()
264
264
scale_diag = yield trainable_state_util .Parameter (
265
- init_fn = lambda seed : diag_bijector ( # pylint: disable=g-long-lambda
265
+ init_fn = lambda seed : diag_bijector . forward ( # pylint: disable=g-long-lambda
266
266
samplers .normal (
267
267
mean = 0. ,
268
268
stddev = scale_initializer ,
You can’t perform that action at this time.
0 commit comments