Skip to content

Commit 58a7885

Browse files
srvasudetensorflower-gardener
authored andcommitted
Remove some dependencies in internal libraries on numpy backend.
PiperOrigin-RevId: 473083253
1 parent 4c5c0f9 commit 58a7885

File tree

4 files changed

+3
-7
lines changed

4 files changed

+3
-7
lines changed

tensorflow_probability/python/internal/BUILD

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,6 @@ multi_substrate_py_library(
344344
srcs = ["implementation_selection.py"],
345345
deps = [
346346
# tensorflow dep,
347-
"//tensorflow_probability/python/internal/backend/numpy",
348347
],
349348
)
350349

@@ -757,7 +756,6 @@ multi_substrate_py_library(
757756
# tensorflow dep,
758757
"//tensorflow_probability/python/bijectors:bijector",
759758
"//tensorflow_probability/python/internal:empirical_statistical_testing",
760-
"//tensorflow_probability/python/internal/backend/numpy",
761759
"//tensorflow_probability/python/util:deferred_tensor",
762760
"//tensorflow_probability/python/util:seed_stream",
763761
],

tensorflow_probability/python/internal/test_util.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from tensorflow_probability.python.internal import empirical_statistical_testing
3535
from tensorflow_probability.python.internal import samplers
3636
from tensorflow_probability.python.internal import test_combinations
37-
from tensorflow_probability.python.internal.backend.numpy import ops
3837
from tensorflow_probability.python.util.deferred_tensor import DeferredTensor
3938
from tensorflow_probability.python.util.deferred_tensor import TransformedVariable
4039
from tensorflow_probability.python.util.seed_stream import SeedStream
@@ -868,7 +867,7 @@ def f(test_fn_or_class):
868867
"""Decorator."""
869868
if JAX_MODE:
870869
return test_fn_or_class
871-
if tf.Variable != ops.NumpyVariable:
870+
if not NUMPY_MODE:
872871
return test_fn_or_class
873872

874873
reason = 'Test disabled for Numpy missing functionality: {}'.format(
@@ -912,7 +911,7 @@ def tf_tape_safety_test(test_fn):
912911
"""Only run a test of TF2 tape safety against the TF backend."""
913912

914913
def new_test(self, *args, **kwargs):
915-
if JAX_MODE or (tf.Variable == ops.NumpyVariable):
914+
if JAX_MODE or NUMPY_MODE:
916915
self.skipTest('Tape-safety tests are only run against TensorFlow.')
917916
return test_fn(self, *args, **kwargs)
918917

tensorflow_probability/python/math/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,6 @@ multi_substrate_py_library(
181181
deps = [
182182
# tensorflow dep,
183183
"//tensorflow_probability/python/internal:tensor_util",
184-
"//tensorflow_probability/python/internal/backend/numpy:tf_inspect",
185184
],
186185
)
187186

tensorflow_probability/python/math/gradient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import tensorflow.compat.v2 as tf
1818

1919
from tensorflow_probability.python.internal import tensor_util
20-
from tensorflow_probability.python.internal.backend.numpy import tf_inspect
20+
from tensorflow.python.util import tf_inspect # pylint: disable=g-direct-tensorflow-import
2121

2222

2323
__all__ = [

0 commit comments

Comments
 (0)