Skip to content

Commit ff464f9

Browse files
alanchiaotensorflower-gardener
authored andcommitted
Migrate majority of pruning API to utilize the tf. public API as opposed to internal versions of those public APIs. Make related updates to support/conform to TF 2.X.
Complete usage of the tf. public API is the only way to ensure that 1.X and 2.X behavior are not mixed and prevents breakages from occurring from TensorFlow file location moves. PiperOrigin-RevId: 285841197
1 parent 256658f commit ff464f9

22 files changed

+515
-392
lines changed

tensorflow_model_optimization/python/core/keras/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,5 @@ py_library(
1515
deps = [
1616
# numpy dep1,
1717
# tensorflow dep1,
18-
# python/keras tensorflow dep2,
1918
],
2019
)

tensorflow_model_optimization/python/core/keras/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
"""Test utilities for generating, saving, and evaluating models."""
16+
# TODO(tf-mot): dedup and migrate to testing/ directory.
1617

1718
import numpy as np
1819
import tensorflow as tf
1920

20-
from tensorflow.python import keras
21-
l = keras.layers
21+
l = tf.keras.layers
2222

2323

2424
def build_simple_dense_model():
25-
return keras.Sequential([
25+
return tf.keras.Sequential([
2626
l.Dense(8, activation='relu', input_shape=(10,)),
2727
l.Dense(5, activation='softmax')
2828
])

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,7 @@ py_library(
2828
deps = [
2929
":pruning_schedule",
3030
":pruning_wrapper",
31-
# python:util tensorflow dep2,
32-
# python/keras tensorflow dep2,
33-
# python/keras:engine tensorflow dep2,
34-
# python/keras:generic_utils tensorflow dep2,
31+
# tensorflow dep1,
3532
],
3633
)
3734

@@ -53,7 +50,6 @@ py_library(
5350
deps = [
5451
":prunable_layer",
5552
# tensorflow dep1,
56-
# python/keras:layers_base tensorflow dep2,
5753
],
5854
)
5955

@@ -64,10 +60,7 @@ py_library(
6460
visibility = ["//visibility:public"],
6561
deps = [
6662
# six dep1,
67-
# python:constant_op tensorflow dep2,
68-
# python:dtypes tensorflow dep2,
69-
# python:framework_ops tensorflow dep2,
70-
# python:math_ops tensorflow dep2,
63+
# tensorflow dep1,
7164
],
7265
)
7366

@@ -81,6 +74,7 @@ py_test(
8174
":pruning_schedule",
8275
# absl/testing:parameterized dep1,
8376
# tensorflow dep1,
77+
# python/keras tensorflow dep2,
8478
],
8579
)
8680

@@ -95,16 +89,8 @@ py_library(
9589
":pruning_impl",
9690
":pruning_schedule",
9791
# numpy dep1,
98-
# python:check_ops tensorflow dep2,
99-
# python:control_flow_ops tensorflow dep2,
100-
# python:dtypes tensorflow dep2,
101-
# python:framework_ops tensorflow dep2,
102-
# python:variables tensorflow dep2,
103-
# python/keras:backend tensorflow dep2,
104-
# python/keras:base_layer tensorflow dep2,
92+
# tensorflow dep1,
10593
# python/keras:generic_utils tensorflow dep2,
106-
# python/keras:initializers tensorflow dep2,
107-
# python/keras:layers_base tensorflow dep2,
10894
# python/keras:tf_utils tensorflow dep2,
10995
],
11096
)
@@ -118,9 +104,6 @@ py_library(
118104
":pruning_wrapper",
119105
# numpy dep1,
120106
# tensorflow dep1,
121-
# python:math_ops tensorflow dep2,
122-
# python/keras:backend tensorflow dep2,
123-
# python/keras:callbacks tensorflow dep2,
124107
],
125108
)
126109

@@ -149,15 +132,8 @@ py_library(
149132
deps = [
150133
":pruning_utils",
151134
# tensorflow dep1,
152-
# python:array_ops tensorflow dep2,
153-
# python:control_flow_ops tensorflow dep2,
154-
# python:dtypes tensorflow dep2,
155-
# python:framework_ops tensorflow dep2,
156-
# python:math_ops tensorflow dep2,
157-
# python:nn_ops tensorflow dep2,
158135
# python:state_ops tensorflow dep2,
159136
# python:summary tensorflow dep2,
160-
# python:variables tensorflow dep2,
161137
],
162138
)
163139

@@ -168,15 +144,7 @@ py_library(
168144
visibility = ["//visibility:public"],
169145
deps = [
170146
# numpy dep1,
171-
# python:array_ops tensorflow dep2,
172-
# python:constant_op tensorflow dep2,
173-
# python:control_flow_ops tensorflow dep2,
174-
# python:dtypes tensorflow dep2,
175-
# python:framework_ops tensorflow dep2,
176-
# python:init_ops tensorflow dep2,
177-
# python:nn_ops tensorflow dep2,
178-
# python:state_ops tensorflow dep2,
179-
# python:variable_scope tensorflow dep2,
147+
# tensorflow dep1,
180148
],
181149
)
182150

@@ -259,7 +227,6 @@ py_test(
259227
# absl/testing:parameterized dep1,
260228
# numpy dep1,
261229
# tensorflow dep1,
262-
# python/keras tensorflow dep2,
263230
"//tensorflow_model_optimization/python/core/keras:test_utils",
264231
],
265232
)
@@ -300,6 +267,8 @@ py_test(
300267
":pruning_impl",
301268
":pruning_schedule",
302269
# numpy dep1,
270+
# python/keras tensorflow dep2,
271+
"//tensorflow_model_optimization/python/core/keras:test_utils",
303272
],
304273
)
305274

tensorflow_model_optimization/python/core/sparsity/keras/prune.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
# pylint: disable=protected-access,missing-docstring,unused-argument
1616
"""Entry point for pruning models during training."""
1717

18-
from tensorflow.python import keras
19-
from tensorflow.python.keras.engine.input_layer import InputLayer
20-
from tensorflow.python.keras.utils.generic_utils import custom_object_scope
18+
import tensorflow as tf
19+
2120
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule as pruning_sched
2221
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
2322

23+
keras = tf.keras
24+
custom_object_scope = tf.keras.utils.custom_object_scope
25+
2426

2527
def prune_scope():
2628
"""Provides a scope in which Pruned layers and models can be deserialized.
@@ -125,7 +127,7 @@ def _prune_list(layers, **params):
125127
# No need to wrap the input layer either.
126128
if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
127129
wrapped_layers.append(layer)
128-
elif isinstance(layer, InputLayer):
130+
elif isinstance(layer, keras.layers.InputLayer):
129131
# TODO(yunluli): Replace with a clone function in keras.
130132
wrapped_layers.append(layer.__class__.from_config(layer.get_config()))
131133
else:

tensorflow_model_optimization/python/core/sparsity/keras/prune_distributed_test.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,28 @@
1717
import tempfile
1818
from absl.testing import parameterized
1919
import numpy as np
20+
import tensorflow as tf
2021

21-
from tensorflow.python import keras
22-
from tensorflow.python.distribute import collective_all_reduce_strategy
23-
from tensorflow.python.distribute import mirrored_strategy
24-
from tensorflow.python.distribute import one_device_strategy
25-
from tensorflow.python.keras.utils import np_utils
26-
from tensorflow.python.platform import test
2722
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
2823
from tensorflow_model_optimization.python.core.sparsity.keras import prune
2924
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
3025
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
3126
from tensorflow_model_optimization.python.core.sparsity.keras import test_utils
3227

28+
keras = tf.keras
29+
3330

3431
def _distribution_strategies():
3532
return [
36-
collective_all_reduce_strategy.CollectiveAllReduceStrategy(),
37-
mirrored_strategy.MirroredStrategy(),
33+
tf.distribute.experimental.MultiWorkerMirroredStrategy(),
34+
tf.distribute.MirroredStrategy(),
3835
# TODO(pulkitb): Add parameter_server
39-
# parameter_server_strategy.ParameterServerStrategy(),
40-
one_device_strategy.OneDeviceStrategy('/cpu:0'),
36+
# tf.distribute.experimental.ParameterServerStrategy,
37+
tf.distribute.OneDeviceStrategy('/cpu:0'),
4138
]
4239

4340

44-
class PruneDistributedTest(test.TestCase, parameterized.TestCase):
41+
class PruneDistributedTest(tf.test.TestCase, parameterized.TestCase):
4542

4643
def setUp(self):
4744
super(PruneDistributedTest, self).setUp()
@@ -67,8 +64,7 @@ def testPrunesSimpleDenseModel(self, distribution):
6764
# Simple unpruned model. No sparsity.
6865
model.fit(
6966
np.random.rand(20, 10),
70-
np_utils.to_categorical(
71-
np.random.randint(5, size=(20, 1)), 5),
67+
keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
7268
epochs=2,
7369
callbacks=[pruning_callbacks.UpdatePruningStep()],
7470
batch_size=20)
@@ -85,4 +81,4 @@ def testPrunesSimpleDenseModel(self, distribution):
8581

8682

8783
if __name__ == '__main__':
88-
test.main()
84+
tf.test.main()

0 commit comments

Comments
 (0)