Skip to content

Commit dc3e232

Browse files
abatterytensorflower-gardener
authored andcommitted
Force to use the keras v2 version to resolve breakages in OSS
PiperOrigin-RevId: 604160487
1 parent b0eea48 commit dc3e232

File tree

4 files changed

+8
-39
lines changed

4 files changed

+8
-39
lines changed

tensorflow_model_optimization/python/core/keras/compat.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,23 @@
1919
from __future__ import print_function
2020

2121
import collections
22+
import os
2223
import weakref
2324

2425
import tensorflow as tf
2526

2627

2728
def _get_keras_instance():
28-
from pkg_resources import parse_version
29-
30-
required_tensorflow_version = '2.16.0'
31-
if parse_version(tf.__version__) < parse_version(required_tensorflow_version):
32-
return tf.keras
29+
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
30+
os.environ['TF_USE_LEGACY_KERAS'] = '1'
3331

32+
# Use Keras 2.
3433
version_fn = getattr(tf.keras, 'version', None)
3534
if version_fn and version_fn().startswith('3.'):
36-
try:
37-
import tf_keras as keras
38-
except ImportError:
39-
pass
40-
return tf.keras
35+
import tf_keras as keras_internal # pylint: disable=g-import-not-at-top,unused-import
36+
else:
37+
keras_internal = tf.keras
38+
return keras_internal
4139

4240

4341
keras = _get_keras_instance()

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,6 @@ py_strict_test(
326326
srcs = ["quantize_models_test.py"],
327327
flaky = True,
328328
python_version = "PY3",
329-
shard_count = 10,
330329
deps = [
331330
":quantize",
332331
":utils",
@@ -343,8 +342,6 @@ py_strict_test(
343342
size = "large",
344343
srcs = ["quantize_functional_test.py"],
345344
python_version = "PY3",
346-
# To match parallel runs of run_all_keras_modes.
347-
shard_count = 4,
348345
deps = [
349346
":quantize",
350347
":utils",

tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121

2222

2323
try:
24-
# OSS case.
25-
import keras # pylint: disable=g-import-not-at-top
2624
if hasattr(keras, 'src'):
2725
# Path as seen in pip packages as of TF/Keras 2.13.
2826
from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top,g-importing-member

tensorflow_model_optimization/python/examples/sparsity/keras/mnist/BUILD

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,30 +11,6 @@ filegroup(
1111
srcs = glob(["**"]),
1212
)
1313

14-
py_strict_binary(
15-
name = "mnist_estimator",
16-
srcs = [
17-
"dataset.py",
18-
"mnist_estimator.py",
19-
],
20-
python_version = "PY3",
21-
deps = [
22-
# absl/flags dep1,
23-
# google/protobuf:use_fast_cpp_protos dep1, # Automatically added
24-
# numpy dep1,
25-
# six dep1,
26-
# tensorflow dep1,
27-
# tensorflow:tensorflow_compat_v1_estimator dep1,
28-
"//tensorflow_model_optimization/python/core/keras:compat",
29-
"//tensorflow_model_optimization/python/core/sparsity/keras:estimator_utils",
30-
"//tensorflow_model_optimization/python/core/sparsity/keras:prune",
31-
"//tensorflow_model_optimization/python/core/sparsity/keras:pruning_schedule",
32-
"//third_party/tensorflow_models/official/common:distribute_utils",
33-
"//third_party/tensorflow_models/official/r1/utils/logs:hooks_helper",
34-
"//third_party/tensorflow_models/official/utils",
35-
],
36-
)
37-
3814
py_strict_binary(
3915
name = "mnist_cnn",
4016
srcs = [

0 commit comments

Comments
 (0)