Skip to content

Commit 94f8988

Browse files
alanchiaotensorflower-gardener
authored andcommitted
Changes to pruning to make code compatible with both TF 1.X and TF 2.0 pip
packages by relying on tf.compat.v1. This is to prepare for tf-nightly switching to 2.0 by default. Switch some imports to using tf. API, including one that is causing a test to fail on TF 1.14. Note that tf.compat.v1 is only available in TF 1.14+, which matches what tf-mot supports. PiperOrigin-RevId: 265127228
1 parent 7ecc78c commit 94f8988

File tree

3 files changed

+35
-41
lines changed

3 files changed

+35
-41
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818
from absl.testing import parameterized
1919
import numpy as np
2020

21-
from tensorflow.python import keras
22-
from tensorflow.python.framework import errors_impl
21+
import tensorflow.compat.v1 as tf
22+
keras = tf.keras
23+
errors_impl = tf.errors
24+
layers = keras.layers
25+
test = tf.test
2326
from tensorflow.python.keras import keras_parameterized
24-
from tensorflow.python.keras import layers
25-
from tensorflow.python.platform import test
2627
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
2728
from tensorflow_model_optimization.python.core.sparsity.keras import prune
2829
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
@@ -349,4 +350,5 @@ def testStripPruningFunctionalModel(self):
349350

350351

351352
if __name__ == '__main__':
353+
tf.disable_v2_behavior()
352354
test.main()

tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,14 @@
1919
from __future__ import print_function
2020
# import g3
2121
import numpy as np
22-
from tensorflow.python.eager import context
23-
from tensorflow.python.framework import constant_op
24-
from tensorflow.python.framework import dtypes
25-
from tensorflow.python.framework import ops
22+
23+
import tensorflow.compat.v1 as tf
24+
# TODO(tf-mot): when migrating to 2.0, K.get_session() no longer exists.
25+
K = tf.keras.backend
26+
dtypes = tf.dtypes
27+
test = tf.test
28+
2629
from tensorflow.python.framework import test_util as tf_test_util
27-
from tensorflow.python.keras import backend as K
28-
from tensorflow.python.ops import math_ops
29-
from tensorflow.python.ops import partitioned_variables
30-
from tensorflow.python.ops import state_ops
31-
from tensorflow.python.ops import variable_scope
32-
from tensorflow.python.ops import variables
33-
from tensorflow.python.platform import test
3430
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_impl
3531
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
3632
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
@@ -66,7 +62,7 @@ def testUpdateSingleMask(self):
6662
mask_before_pruning = K.get_value(mask)
6763
self.assertAllEqual(np.count_nonzero(mask_before_pruning), 100)
6864

69-
if context.executing_eagerly():
65+
if tf.executing_eagerly():
7066
p.conditional_mask_update()
7167
else:
7268
K.get_session().run(p.conditional_mask_update())
@@ -121,7 +117,7 @@ def testBlockMaskingAvg(self):
121117
def testBlockMaskingMax(self):
122118
block_size = (2, 2)
123119
block_pooling_type = "MAX"
124-
weight = constant_op.constant([[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2],
120+
weight = tf.constant([[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2],
125121
[0.3, 0.0, 0.4, 0.0], [0.0, -0.3, 0.0,
126122
-0.4]])
127123
expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
@@ -133,7 +129,7 @@ def testBlockMaskingWithHigherDimensionsRaisesError(self):
133129
block_size = (2, 2)
134130
block_pooling_type = "AVG"
135131
# Weights as in testBlockMasking, but with one extra dimension.
136-
weight = constant_op.constant([[[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2],
132+
weight = tf.constant([[[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2],
137133
[0.3, 0.3, 0.4, 0.4], [0.3, 0.3, 0.4,
138134
0.4]]])
139135
expected_mask = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
@@ -149,9 +145,9 @@ def testConditionalMaskUpdate(self):
149145
threshold = K.zeros([])
150146

151147
def linear_sparsity(step):
152-
sparsity_val = ops.convert_to_tensor(
148+
sparsity_val = tf.convert_to_tensor(
153149
[0.0, 0.1, 0.1, 0.3, 0.3, 0.5, 0.5, 0.5, 0.5, 0.5])
154-
return ops.convert_to_tensor(True), sparsity_val[step]
150+
return tf.convert_to_tensor(True), sparsity_val[step]
155151

156152
# Set up pruning
157153
p = pruning_impl.Pruning(
@@ -163,14 +159,14 @@ def linear_sparsity(step):
163159

164160
non_zero_count = []
165161
for _ in range(10):
166-
if context.executing_eagerly():
162+
if tf.executing_eagerly():
167163
p.conditional_mask_update()
168164
p.weight_mask_op()
169-
state_ops.assign_add(self.global_step, 1)
165+
tf.assign_add(self.global_step, 1)
170166
else:
171167
K.get_session().run(p.conditional_mask_update())
172168
K.get_session().run(p.weight_mask_op())
173-
K.get_session().run(state_ops.assign_add(self.global_step, 1))
169+
K.get_session().run(tf.assign_add(self.global_step, 1))
174170

175171
non_zero_count.append(np.count_nonzero(K.get_value(weight)))
176172

@@ -180,4 +176,5 @@ def linear_sparsity(step):
180176

181177

182178
if __name__ == "__main__":
179+
tf.disable_v2_behavior()
183180
test.main()

tensorflow_model_optimization/python/core/sparsity/keras/pruning_utils_test.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,20 @@
2121
# import g3
2222
from absl.testing import parameterized
2323

24-
from tensorflow.python.ops import array_ops
25-
from tensorflow.python.ops import nn_ops
26-
from tensorflow.python.ops import random_ops
27-
from tensorflow.python.ops import variable_scope
28-
from tensorflow.python.ops import variables
29-
from tensorflow.python.platform import test
24+
import tensorflow.compat.v1 as tf
3025
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
3126

32-
3327
@parameterized.named_parameters(
3428
("1x1", [1, 1]), ("4x4", [4, 4]), ("6x6", [6, 6]), ("1x4", [1, 4]),
3529
("4x1", [4, 1]), ("1x8", [1, 8]), ("8x1", [8, 1]))
36-
class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase):
30+
class PruningUtilsParameterizedTest(tf.test.TestCase, parameterized.TestCase):
3731

3832
def _compare_pooling_methods(self, weights, pooling_kwargs):
3933
with self.cached_session():
40-
variables.global_variables_initializer().run()
41-
pooled_weights_tf = array_ops.squeeze(
42-
nn_ops.pool(
43-
array_ops.reshape(
34+
tf.global_variables_initializer().run()
35+
pooled_weights_tf = tf.squeeze(
36+
tf.nn.pool(
37+
tf.reshape(
4438
weights,
4539
[1, weights.get_shape()[0],
4640
weights.get_shape()[1], 1]), **pooling_kwargs))
@@ -51,16 +45,16 @@ def _compare_pooling_methods(self, weights, pooling_kwargs):
5145

5246
def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
5347
with self.cached_session() as session:
54-
variables.global_variables_initializer().run()
48+
tf.global_variables_initializer().run()
5549
expanded_tensor = pruning_utils.expand_tensor(tensor, block_dim)
5650
kronecker_product = pruning_utils.kronecker_product(
57-
tensor, array_ops.ones(block_dim))
51+
tensor, tf.ones(block_dim))
5852
expanded_tensor_val, kronecker_product_val = session.run(
5953
[expanded_tensor, kronecker_product])
6054
self.assertAllEqual(expanded_tensor_val, kronecker_product_val)
6155

6256
def testFactorizedAvgPool(self, window_shape):
63-
weights = variable_scope.get_variable("weights", shape=[1024, 2048])
57+
weights = tf.get_variable("weights", shape=[1024, 2048])
6458
pooling_kwargs = {
6559
"window_shape": window_shape,
6660
"pooling_type": "AVG",
@@ -70,7 +64,7 @@ def testFactorizedAvgPool(self, window_shape):
7064
self._compare_pooling_methods(weights, pooling_kwargs)
7165

7266
def testFactorizedMaxPool(self, window_shape):
73-
weights = variable_scope.get_variable("weights", shape=[1024, 2048])
67+
weights = tf.get_variable("weights", shape=[1024, 2048])
7468
pooling_kwargs = {
7569
"window_shape": window_shape,
7670
"pooling_type": "MAX",
@@ -80,9 +74,10 @@ def testFactorizedMaxPool(self, window_shape):
8074
self._compare_pooling_methods(weights, pooling_kwargs)
8175

8276
def testExpandTensor(self, block_dim):
83-
weights = random_ops.random_normal(shape=[1024, 512])
77+
weights = tf.random.normal(shape=[1024, 512])
8478
self._compare_expand_tensor_with_kronecker_product(weights, block_dim)
8579

8680

8781
if __name__ == "__main__":
88-
test.main()
82+
tf.disable_v2_behavior()
83+
tf.test.main()

0 commit comments

Comments
 (0)