File tree Expand file tree Collapse file tree 4 files changed +8
-39
lines changed
tensorflow_model_optimization/python
examples/sparsity/keras/mnist Expand file tree Collapse file tree 4 files changed +8
-39
lines changed Original file line number Diff line number Diff line change 19
19
from __future__ import print_function
20
20
21
21
import collections
22
+ import os
22
23
import weakref
23
24
24
25
import tensorflow as tf
25
26
26
27
27
28
def _get_keras_instance ():
28
- from pkg_resources import parse_version
29
-
30
- required_tensorflow_version = '2.16.0'
31
- if parse_version (tf .__version__ ) < parse_version (required_tensorflow_version ):
32
- return tf .keras
29
+ # Keep using keras-2 (tf-keras) rather than keras-3 (keras).
30
+ os .environ ['TF_USE_LEGACY_KERAS' ] = '1'
33
31
32
+ # Use Keras 2.
34
33
version_fn = getattr (tf .keras , 'version' , None )
35
34
if version_fn and version_fn ().startswith ('3.' ):
36
- try :
37
- import tf_keras as keras
38
- except ImportError :
39
- pass
40
- return tf .keras
35
+ import tf_keras as keras_internal # pylint: disable=g-import-not-at-top,unused-import
36
+ else :
37
+ keras_internal = tf .keras
38
+ return keras_internal
41
39
42
40
43
41
keras = _get_keras_instance ()
Original file line number Diff line number Diff line change @@ -326,7 +326,6 @@ py_strict_test(
326
326
srcs = ["quantize_models_test.py" ],
327
327
flaky = True ,
328
328
python_version = "PY3" ,
329
- shard_count = 10 ,
330
329
deps = [
331
330
":quantize" ,
332
331
":utils" ,
@@ -343,8 +342,6 @@ py_strict_test(
343
342
size = "large" ,
344
343
srcs = ["quantize_functional_test.py" ],
345
344
python_version = "PY3" ,
346
- # To match parallel runs of run_all_keras_modes.
347
- shard_count = 4 ,
348
345
deps = [
349
346
":quantize" ,
350
347
":utils" ,
Original file line number Diff line number Diff line change 21
21
22
22
23
23
try :
24
- # OSS case.
25
- import keras # pylint: disable=g-import-not-at-top
26
24
if hasattr (keras , 'src' ):
27
25
# Path as seen in pip packages as of TF/Keras 2.13.
28
26
from keras .src .engine import base_layer # pylint: disable=g-import-not-at-top,g-importing-member
Original file line number Diff line number Diff line change @@ -11,30 +11,6 @@ filegroup(
11
11
srcs = glob (["**" ]),
12
12
)
13
13
14
- py_strict_binary (
15
- name = "mnist_estimator" ,
16
- srcs = [
17
- "dataset.py" ,
18
- "mnist_estimator.py" ,
19
- ],
20
- python_version = "PY3" ,
21
- deps = [
22
- # absl/flags dep1,
23
- # google/protobuf:use_fast_cpp_protos dep1, # Automatically added
24
- # numpy dep1,
25
- # six dep1,
26
- # tensorflow dep1,
27
- # tensorflow:tensorflow_compat_v1_estimator dep1,
28
- "//tensorflow_model_optimization/python/core/keras:compat" ,
29
- "//tensorflow_model_optimization/python/core/sparsity/keras:estimator_utils" ,
30
- "//tensorflow_model_optimization/python/core/sparsity/keras:prune" ,
31
- "//tensorflow_model_optimization/python/core/sparsity/keras:pruning_schedule" ,
32
- "//third_party/tensorflow_models/official/common:distribute_utils" ,
33
- "//third_party/tensorflow_models/official/r1/utils/logs:hooks_helper" ,
34
- "//third_party/tensorflow_models/official/utils" ,
35
- ],
36
- )
37
-
38
14
py_strict_binary (
39
15
name = "mnist_cnn" ,
40
16
srcs = [
You can’t perform that action at this time.
0 commit comments