Skip to content

Commit 01ededf

Browse files
lingvo-botcopybara-github
authored andcommitted
Update xla_sharding import path to new location
We are moving the TensorFlow APIs outside of XLA and will remove the old path soon. PiperOrigin-RevId: 477859491
1 parent 43c305a commit 01ededf

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

lingvo/core/gshard_layers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@
2828

2929
# pylint: disable=g-direct-tensorflow-import
3030
from tensorflow.compiler.tf2xla.python import xla
31-
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
31+
# pylint: disable=g-import-not-at-top
32+
try:
33+
from tensorflow.python.compiler.xla.experimental import xla_sharding
34+
except ImportError:
35+
# OSS backward compatibility, can be removed when TF is updated.
36+
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
3237
# pylint: enable=g-direct-tensorflow-import
3338

3439
Split = gshard_utils.Split

lingvo/core/gshard_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525
import sentencepiece as sentencepiece_processor
2626
# pylint: disable=g-direct-tensorflow-import
2727
from tensorflow.compiler.xla import xla_data_pb2
28-
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
28+
# pylint: disable=g-import-not-at-top
29+
try:
30+
from tensorflow.python.compiler.xla.experimental import xla_sharding
31+
except ImportError:
32+
# OSS backward compatibility, can be removed when TF is updated.
33+
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
2934
# pylint: enable=g-direct-tensorflow-import
3035

3136
ThreadLocalStack = thread_local_utils.ThreadLocalStack

lingvo/core/var_tmp_wrappers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
from lingvo import compat as tf
2323

2424
# pylint: disable=g-direct-tensorflow-import
25-
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
25+
# pylint: disable=g-import-not-at-top
26+
try:
27+
from tensorflow.python.compiler.xla.experimental import xla_sharding
28+
except ImportError:
29+
# OSS backward compatibility, can be removed when TF is updated.
30+
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
2631
# pylint: enable=g-direct-tensorflow-import
2732

2833

0 commit comments

Comments
 (0)