Skip to content

Commit b02397e

Browse files
tf-transform-teamtfx-copybara
authored andcommitted
Explicitly set TF_USE_LEGACY_KERAS=1 when importing Keras.
Tensorflow Transform is only compatible with Keras 2. PiperOrigin-RevId: 687967427
1 parent 257a24b commit b02397e

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

tensorflow_transform/keras_lib.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,22 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Imports keras 2."""
15+
import os
16+
17+
from absl import logging
1518
import tensorflow as tf
1619

20+
if 'TF_USE_LEGACY_KERAS' not in os.environ:
21+
# Make sure we are using Keras 2.
22+
os.environ['TF_USE_LEGACY_KERAS'] = '1'
23+
elif os.environ['TF_USE_LEGACY_KERAS'] not in ('true', 'True', '1'):
24+
logging.warning(
25+
'TF_USE_LEGACY_KERAS is set to %s, which will not use Keras 2. Tensorflow'
26+
' Transform is only compatible with Keras 2. Please set'
27+
' TF_USE_LEGACY_KERAS=1.',
28+
os.environ['TF_USE_LEGACY_KERAS'],
29+
)
30+
1731
version_fn = getattr(tf.keras, 'version', None)
1832
if version_fn and version_fn().startswith('3.'):
1933
# `tf.keras` points to `keras 3`, so use `tf_keras` package

0 commit comments

Comments
 (0)