Skip to content

Commit c69428c

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Move composite_tensor.py into tfp.experimental.util.
This is so I can add a tfp.experimental.math without triggering the usual Python import shadowing issue. PiperOrigin-RevId: 390180664
1 parent 3069fc4 commit c69428c

File tree

6 files changed

+35
-29
lines changed

6 files changed

+35
-29
lines changed

tensorflow_probability/python/experimental/BUILD

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ multi_substrate_py_library(
3939
],
4040
srcs_version = "PY3",
4141
substrates_omit_deps = [
42-
":composite_tensor",
4342
"//tensorflow_probability/python/experimental/auto_batching",
4443
"//tensorflow_probability/python/experimental/linalg",
4544
"//tensorflow_probability/python/experimental/marginalize",
@@ -49,9 +48,9 @@ multi_substrate_py_library(
4948
"//tensorflow_probability/python/experimental/substrates",
5049
"//tensorflow_probability/python/experimental/vi",
5150
"//tensorflow_probability/python/internal:auto_composite_tensor",
51+
"//tensorflow_probability/python/experimental/util:composite_tensor",
5252
],
5353
deps = [
54-
":composite_tensor",
5554
"//tensorflow_probability/python/experimental/auto_batching",
5655
"//tensorflow_probability/python/experimental/bijectors",
5756
"//tensorflow_probability/python/experimental/distribute",
@@ -66,31 +65,9 @@ multi_substrate_py_library(
6665
"//tensorflow_probability/python/experimental/sts_gibbs",
6766
"//tensorflow_probability/python/experimental/substrates",
6867
"//tensorflow_probability/python/experimental/util",
68+
"//tensorflow_probability/python/experimental/util:composite_tensor",
6969
"//tensorflow_probability/python/experimental/vi",
7070
"//tensorflow_probability/python/internal:all_util",
7171
"//tensorflow_probability/python/internal:auto_composite_tensor",
7272
],
7373
)
74-
75-
py_library(
76-
name = "composite_tensor",
77-
srcs = ["composite_tensor.py"],
78-
srcs_version = "PY3",
79-
deps = [
80-
# tensorflow dep,
81-
"//tensorflow_probability/python/distributions",
82-
],
83-
)
84-
85-
py_test(
86-
name = "composite_tensor_test",
87-
srcs = ["composite_tensor_test.py"],
88-
python_version = "PY3",
89-
srcs_version = "PY3",
90-
deps = [
91-
# numpy dep,
92-
# tensorflow dep,
93-
"//tensorflow_probability",
94-
"//tensorflow_probability/python/internal:test_util",
95-
],
96-
)

tensorflow_probability/python/experimental/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@
4545
from tensorflow_probability.python.experimental import substrates
4646
from tensorflow_probability.python.experimental import util
4747
from tensorflow_probability.python.experimental import vi
48-
from tensorflow_probability.python.experimental.composite_tensor import as_composite
49-
from tensorflow_probability.python.experimental.composite_tensor import register_composite
48+
from tensorflow_probability.python.experimental.util.composite_tensor import as_composite
49+
from tensorflow_probability.python.experimental.util.composite_tensor import register_composite
5050
from tensorflow_probability.python.internal import all_util
5151
from tensorflow_probability.python.internal.auto_composite_tensor import auto_composite_tensor
5252
from tensorflow_probability.python.internal.auto_composite_tensor import AutoCompositeTensor

tensorflow_probability/python/experimental/util/BUILD

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ multi_substrate_py_library(
3737
numpy_omit_deps = [":jit_public_methods"],
3838
srcs_version = "PY3",
3939
substrates_omit_deps = [
40+
":composite_tensor",
4041
":deferred_module",
4142
":trainable",
4243
],
4344
deps = [
45+
":composite_tensor",
4446
":deferred_module",
4547
":jit_public_methods",
4648
":trainable",
@@ -133,3 +135,26 @@ py_test(
133135
"//tensorflow_probability/python/internal:test_util",
134136
],
135137
)
138+
139+
py_library(
140+
name = "composite_tensor",
141+
srcs = ["composite_tensor.py"],
142+
srcs_version = "PY3",
143+
deps = [
144+
# tensorflow dep,
145+
"//tensorflow_probability/python/distributions",
146+
],
147+
)
148+
149+
py_test(
150+
name = "composite_tensor_test",
151+
srcs = ["composite_tensor_test.py"],
152+
python_version = "PY3",
153+
srcs_version = "PY3",
154+
deps = [
155+
# numpy dep,
156+
# tensorflow dep,
157+
"//tensorflow_probability",
158+
"//tensorflow_probability/python/internal:test_util",
159+
],
160+
)

tensorflow_probability/python/experimental/util/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
from tensorflow_probability.python.experimental.util.composite_tensor import as_composite
22+
from tensorflow_probability.python.experimental.util.composite_tensor import register_composite
2123
from tensorflow_probability.python.experimental.util.deferred_module import DeferredModule
2224
from tensorflow_probability.python.experimental.util.jit_public_methods import DEFAULT_METHODS_EXCLUDED_FROM_JIT
2325
from tensorflow_probability.python.experimental.util.jit_public_methods import JitPublicMethods
@@ -26,10 +28,12 @@
2628

2729

2830
_allowed_symbols = [
31+
'as_composite',
2932
'DEFAULT_METHODS_EXCLUDED_FROM_JIT',
3033
'DeferredModule',
3134
'JitPublicMethods',
32-
'make_trainable'
35+
'make_trainable',
36+
'register_composite',
3337
]
3438

3539
all_util.remove_undocumented(__name__, _allowed_symbols)

tensorflow_probability/python/experimental/composite_tensor_test.py renamed to tensorflow_probability/python/experimental/util/composite_tensor_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import tensorflow.compat.v2 as tf
2727
import tensorflow_probability as tfp
2828

29-
from tensorflow_probability.python.experimental.composite_tensor import _registry as clsid_registry
29+
from tensorflow_probability.python.experimental.util.composite_tensor import _registry as clsid_registry
3030
from tensorflow_probability.python.internal import test_util as tfp_test_util
3131
from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import
3232

0 commit comments

Comments
 (0)