Skip to content

Commit 56fa13f

Browse files
Replace tensorflow.python.keras with keras. tensorflow.python.keras is an old copy and is deprecated.
PiperOrigin-RevId: 485943740
1 parent ceb0898 commit 56fa13f

File tree

8 files changed

+12
-32
lines changed

8 files changed

+12
-32
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ py_strict_library(
5151
visibility = ["//visibility:public"],
5252
deps = [
5353
":prunable_layer",
54+
# keras/engine:base_layer dep1,
5455
# tensorflow dep1,
55-
# python/keras:base_layer tensorflow dep2,
5656
],
5757
)
5858

@@ -104,9 +104,9 @@ py_strict_library(
104104
":pruning_impl",
105105
":pruning_schedule",
106106
":pruning_utils",
107+
# keras/utils:generic_utils dep1,
107108
# numpy dep1,
108-
# tensorflow dep1,
109-
# python/keras/utils:generic_utils tensorflow dep2,
109+
# tensorflow:tensorflow_no_contrib dep1,
110110
"//tensorflow_model_optimization/python/core/keras:compat",
111111
"//tensorflow_model_optimization/python/core/keras:metrics",
112112
"//tensorflow_model_optimization/python/core/keras:utils",

tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import tensorflow as tf
2222

2323
# TODO(b/139939526): move to public API.
24-
from tensorflow.python.keras import keras_parameterized
2524
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
2625
from tensorflow_model_optimization.python.core.sparsity.keras import prune
2726
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
@@ -37,7 +36,6 @@
3736
ModelCompare = keras_test_utils.ModelCompare
3837

3938

40-
@keras_parameterized.run_all_keras_modes
4139
class PruneIntegrationTest(tf.test.TestCase, parameterized.TestCase,
4240
ModelCompare):
4341

@@ -691,7 +689,6 @@ def testPruneWithPolynomialDecayPastEndStep_PreservesSparsity(
691689
self._check_strip_pruning_matches_original(model, 0.6)
692690

693691

694-
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
695692
class PruneIntegrationCustomTrainingLoopTest(tf.test.TestCase,
696693
parameterized.TestCase):
697694

tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
# ==============================================================================
1515
"""Registry responsible for built-in keras classes."""
1616

17+
from keras.engine import base_layer
1718
import tensorflow as tf
1819

19-
# TODO(b/139939526): move to public API.
20-
from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
2120
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
2221

22+
# TODO(b/139939526): move to public API.
23+
2324
layers = tf.keras.layers
2425
layers_compat_v1 = tf.compat.v1.keras.layers
2526

@@ -100,7 +101,7 @@ class PruneRegistry(object):
100101
],
101102
layers.experimental.SyncBatchNormalization: [],
102103
layers.experimental.preprocessing.Rescaling.__class__: [],
103-
TensorFlowOpLayer: [],
104+
base_layer.TensorFlowOpLayer: [],
104105
layers_compat_v1.BatchNormalization: [],
105106
}
106107

tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,10 @@
1717
import json
1818
import tempfile
1919

20-
from absl.testing import parameterized
2120
import numpy as np
2221
import tensorflow as tf
2322

2423
# TODO(b/139939526): move to public API.
25-
from tensorflow.python.keras import keras_parameterized
2624
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
2725
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
2826
from tensorflow_model_optimization.python.core.sparsity.keras import prune
@@ -57,7 +55,7 @@ class CustomNonPrunableLayer(layers.Dense):
5755
pass
5856

5957

60-
class PruneTest(test.TestCase, parameterized.TestCase):
58+
class PruneTest(test.TestCase):
6159

6260
INVALID_TO_PRUNE_PARAM_ERROR = ('`prune_low_magnitude` can only prune an '
6361
'object of the following types: '
@@ -202,7 +200,6 @@ def testPruneValidLayersListSuccessful(self):
202200
for layer, pruned_layer in zip(model_layers, pruned_layers):
203201
self._validate_pruned_layer(layer, pruned_layer)
204202

205-
@keras_parameterized.run_all_keras_modes
206203
def testPruneInferenceWorks_PruningStepCallbackNotRequired(self):
207204
model = prune.prune_low_magnitude(
208205
keras.Sequential([

tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks_test.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,10 @@
1717
import os
1818
import tempfile
1919

20-
from absl.testing import parameterized
2120
import numpy as np
2221
import tensorflow as tf
2322

2423
# TODO(b/139939526): move to public API.
25-
from tensorflow.python.keras import keras_parameterized
2624
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
2725
from tensorflow_model_optimization.python.core.sparsity.keras import prune
2826
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
@@ -31,7 +29,7 @@
3129
errors_impl = tf.errors
3230

3331

34-
class PruneCallbacksTest(tf.test.TestCase, parameterized.TestCase):
32+
class PruneCallbacksTest(tf.test.TestCase):
3533

3634
_BATCH_SIZE = 20
3735

@@ -55,7 +53,6 @@ def _pruned_model_setup(self, custom_training_loop=False):
5553
pruned_model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])
5654
return pruned_model, x_train, y_train
5755

58-
@keras_parameterized.run_all_keras_modes
5956
def testUpdatePruningStepsAndLogsSummaries(self):
6057
log_dir = tempfile.mkdtemp()
6158
pruned_model, x_train, y_train = self._pruned_model_setup()
@@ -77,7 +74,6 @@ def testUpdatePruningStepsAndLogsSummaries(self):
7774
self._assertLogsExist(log_dir)
7875

7976
# This style of custom training loop isn't available in graph mode.
80-
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
8177
def testUpdatePruningStepsAndLogsSummaries_CustomTrainingLoop(self):
8278
log_dir = tempfile.mkdtemp()
8379
pruned_model, loss, optimizer, x_train, y_train = self._pruned_model_setup(
@@ -116,7 +112,6 @@ def testUpdatePruningStepsAndLogsSummaries_CustomTrainingLoop(self):
116112
3, tf.keras.backend.get_value(pruned_model.layers[1].pruning_step))
117113
self._assertLogsExist(log_dir)
118114

119-
@keras_parameterized.run_all_keras_modes
120115
def testUpdatePruningStepsAndLogsSummaries_RunInference(self):
121116
pruned_model, _, _, x_train, _ = self._pruned_model_setup(
122117
custom_training_loop=True)
@@ -128,7 +123,6 @@ def testUpdatePruningStepsAndLogsSummaries_RunInference(self):
128123
self.assertEqual(
129124
-1, tf.keras.backend.get_value(pruned_model.layers[1].pruning_step))
130125

131-
@keras_parameterized.run_all_keras_modes
132126
def testPruneTrainingRaisesError_PruningStepCallbackMissing(self):
133127
pruned_model, x_train, y_train = self._pruned_model_setup()
134128

@@ -137,7 +131,6 @@ def testPruneTrainingRaisesError_PruningStepCallbackMissing(self):
137131
pruned_model.fit(x_train, y_train)
138132

139133
# This style of custom training loop isn't available in graph mode.
140-
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
141134
def testPruneTrainingLoopRaisesError_PruningStepCallbackMissing_CustomTrainingLoop(
142135
self):
143136
pruned_model, _, _, x_train, _ = self._pruned_model_setup(
@@ -149,7 +142,6 @@ def testPruneTrainingLoopRaisesError_PruningStepCallbackMissing_CustomTrainingLo
149142
with tf.GradientTape():
150143
pruned_model(inp, training=True)
151144

152-
@keras_parameterized.run_all_keras_modes
153145
def testPruningSummariesRaisesError_LogDirNotNonEmptyString(self):
154146
with self.assertRaises(ValueError):
155147
pruning_callbacks.PruningSummaries(log_dir='')

tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@
2525
import tensorflow as tf
2626

2727
# TODO(b/139939526): move to public API.
28-
from tensorflow.python.keras import keras_parameterized
2928
from tensorflow_model_optimization.python.core.keras import compat
3029
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_impl
3130
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
32-
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
3331

3432
K = tf.keras.backend
3533
dtypes = tf.dtypes
@@ -43,7 +41,6 @@ def assign_add(ref, value):
4341
return ref.assign_add(value)
4442

4543

46-
@keras_parameterized.run_all_keras_modes
4744
class PruningTest(test.TestCase, parameterized.TestCase):
4845

4946
def setUp(self):

tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import tensorflow as tf
1919

2020
# TODO(b/139939526): move to public API.
21-
from tensorflow.python.keras import keras_parameterized
2221
from tensorflow_model_optimization.python.core.keras import compat
2322
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
2423

@@ -155,7 +154,6 @@ def testSparsityValueIsValid(self, schedule_type):
155154

156155
# Tests to ensure begin_step, end_step, frequency are used correctly.
157156

158-
@keras_parameterized.run_all_keras_modes
159157
@parameterized.named_parameters(
160158
{
161159
'testcase_name': 'ConstantSparsity',
@@ -189,7 +187,6 @@ def testPrunesOnlyInBeginEndStepRange(self, schedule_type):
189187
self.assertFalse(self.evaluate(sparsity(step_201))[0])
190188
self.assertFalse(self.evaluate(sparsity(step_210))[0])
191189

192-
@keras_parameterized.run_all_keras_modes
193190
@parameterized.named_parameters(
194191
{
195192
'testcase_name': 'ConstantSparsity',
@@ -216,7 +213,6 @@ def testOnlyPrunesAtValidFrequencySteps(self, schedule_type):
216213

217214
class ConstantSparsityTest(tf.test.TestCase, parameterized.TestCase):
218215

219-
@keras_parameterized.run_all_keras_modes
220216
def testPrunesForeverIfEndStepIsNegativeOne(self):
221217
sparsity = pruning_schedule.ConstantSparsity(0.5, 0, -1, 10)
222218

@@ -230,7 +226,6 @@ def testPrunesForeverIfEndStepIsNegativeOne(self):
230226
self.assertAllClose(0.5, self.evaluate(sparsity(step_10000))[1])
231227
self.assertAllClose(0.5, self.evaluate(sparsity(step_100000000))[1])
232228

233-
@keras_parameterized.run_all_keras_modes
234229
def testPrunesWithConstantSparsity(self):
235230
sparsity = pruning_schedule.ConstantSparsity(0.5, 100, 200, 10)
236231

@@ -263,7 +258,6 @@ def testRaisesErrorIfEndStepIsNegative(self):
263258
with self.assertRaises(ValueError):
264259
pruning_schedule.PolynomialDecay(0.4, 0.8, 10, -1)
265260

266-
@keras_parameterized.run_all_keras_modes
267261
def testPolynomialDecay_PrunesCorrectly(self):
268262
sparsity = pruning_schedule.PolynomialDecay(0.2, 0.8, 100, 110, 3, 2)
269263

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
from __future__ import print_function
2121

2222
import inspect
23+
2324
# import g3
25+
26+
from keras.utils import generic_utils
2427
import numpy as np
2528
import tensorflow as tf
2629

2730
# TODO(b/139939526): update to use public API.
28-
from tensorflow.python.keras.utils import generic_utils
2931
from tensorflow_model_optimization.python.core.keras import compat as tf_compat
3032
from tensorflow_model_optimization.python.core.keras import metrics
3133
from tensorflow_model_optimization.python.core.keras import utils

0 commit comments

Comments
 (0)