Skip to content

Commit dbae704

Browse files
alanchiaotensorflower-gardener
authored andcommitted
Migrate pruning assign to 2.X/1.X public APIs.
PiperOrigin-RevId: 286029751
1 parent ff464f9 commit dbae704

File tree

4 files changed

+48
-30
lines changed

4 files changed

+48
-30
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ py_library(
132132
deps = [
133133
":pruning_utils",
134134
# tensorflow dep1,
135-
# python:state_ops tensorflow dep2,
136135
# python:summary tensorflow dep2,
137136
],
138137
)

tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020

2121
import tensorflow as tf
2222

23-
# b/(139939526): update assign ops to v2 API.
24-
from tensorflow.python.ops import state_ops
2523
from tensorflow.python.ops import summary_ops_v2
2624
from tensorflow.python.summary import summary as summary_ops_v1
2725
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
2826

27+
2928
class Pruning(object):
3029
"""Implementation of magnitude-based weight pruning."""
3130

@@ -54,6 +53,13 @@ def __init__(self, training_step_fn, pruning_vars, pruning_schedule,
5453

5554
self._validate_block()
5655

56+
@staticmethod
57+
def _assign(ref, value):
58+
if tf.__version__[0] == '1':
59+
return tf.assign(ref, value)
60+
else:
61+
return ref.assign(value)
62+
5763
def _validate_block(self):
5864
if self._block_size != [1, 1]:
5965
for weight, _, _ in self._pruning_vars:
@@ -144,8 +150,15 @@ def _maybe_update_block_mask(self, weights):
144150
squeezed_weights.get_shape()[1]])
145151
return new_threshold, tf.reshape(sliced_mask, tf.shape(weights))
146152

147-
def _get_weight_assign_ops(self):
148-
"""Gather the assign ops for assigning weights<=weights*mask."""
153+
def _weight_assign_objs(self):
154+
"""Gather the assign objs for assigning weights<=weights*mask.
155+
156+
The objs are ops for graph execution and tensors for eager
157+
execution.
158+
159+
Returns:
160+
group of objs for weight assignment.
161+
"""
149162

150163
def update_fn(distribution, values_and_vars):
151164
# TODO(yunluli): Need this ReduceOp because the weight is created by the
@@ -158,34 +171,34 @@ def update_fn(distribution, values_and_vars):
158171
values_and_vars = zip(reduced_values, var_list)
159172

160173
def update_var(variable, reduced_value):
161-
return state_ops.assign(variable, reduced_value)
174+
return self._assign(variable, reduced_value)
162175

163-
update_ops = []
176+
update_objs = []
164177
for value, var in values_and_vars:
165-
update_ops.append(
178+
update_objs.append(
166179
distribution.extended.update(var, update_var, args=(value,)))
167180

168-
return tf.group(update_ops)
181+
return tf.group(update_objs)
169182

170-
assign_ops = []
183+
assign_objs = []
171184

172185
if tf.distribute.get_replica_context():
173186
values_and_vars = []
174187
for weight, mask, _ in self._pruning_vars:
175188
masked_weight = tf.math.multiply(weight, mask)
176189
values_and_vars.append((masked_weight, weight))
177190
if values_and_vars:
178-
assign_ops.append(tf.distribute.get_replica_context().merge_call(
191+
assign_objs.append(tf.distribute.get_replica_context().merge_call(
179192
update_fn, args=(values_and_vars,)))
180193
else:
181194
for weight, mask, _ in self._pruning_vars:
182195
masked_weight = tf.math.multiply(weight, mask)
183-
assign_ops.append(state_ops.assign(weight, masked_weight))
196+
assign_objs.append(self._assign(weight, masked_weight))
184197

185-
return assign_ops
198+
return assign_objs
186199

187200
def weight_mask_op(self):
188-
return tf.group(self._get_weight_assign_ops())
201+
return tf.group(self._weight_assign_objs())
189202

190203
def conditional_mask_update(self):
191204
"""Returns an op to updates masks as per the pruning schedule."""
@@ -200,35 +213,39 @@ def mask_update():
200213
"""Updates mask without distribution strategy."""
201214

202215
def update():
203-
assign_ops = []
216+
assign_objs = []
204217

205218
for weight, mask, threshold in self._pruning_vars:
206219
new_threshold, new_mask = self._maybe_update_block_mask(weight)
207-
assign_ops.append(state_ops.assign(threshold, new_threshold))
208-
assign_ops.append(state_ops.assign(mask, new_mask))
220+
assign_objs.append(self._assign(threshold, new_threshold))
221+
assign_objs.append(self._assign(mask, new_mask))
209222

210-
return tf.group(assign_ops)
223+
return tf.group(assign_objs)
211224

212225
return tf.cond(maybe_update_masks(), update, no_update)
213226

214227
def mask_update_distributed(distribution):
215228
"""Updates mask with distribution strategy."""
216229

217230
def update(var, value):
218-
return state_ops.assign(var, value)
231+
return self._assign(var, value)
219232

220233
def update_distributed():
221-
"""Gather distributed update ops."""
222-
assign_ops = []
234+
"""Gather distributed update objs.
235+
236+
The objs are ops for graph execution and tensors for eager
237+
execution.
238+
"""
239+
assign_objs = []
223240

224241
for weight, mask, threshold in self._pruning_vars:
225242
new_threshold, new_mask = self._maybe_update_block_mask(weight)
226-
assign_ops.append(
243+
assign_objs.append(
227244
distribution.extended.update(mask, update, (new_mask,)))
228-
assign_ops.append(
245+
assign_objs.append(
229246
distribution.extended.update(threshold, update, (new_threshold,)))
230247

231-
return tf.group(assign_ops)
248+
return tf.group(assign_objs)
232249

233250
return tf.cond(maybe_update_masks(), update_distributed, no_update)
234251

tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def setUp(self):
5959
# setUp() lies outside of the "eager scope" that wraps the test cases
6060
# themselves, resulting in initializing graph tensors instead of eager
6161
# tensors when testing eager execution.
62-
def initialize_training_step_fn_and_all_variables(self):
62+
def initialize(self):
6363
self.global_step = tf.Variable(
6464
tf.zeros([], dtype=dtypes.int32),
6565
dtype=dtypes.int32,
@@ -81,7 +81,7 @@ def testUpdateSingleMask(self):
8181
dtype=weight_dtype)
8282
threshold = tf.Variable(
8383
tf.zeros([], dtype=weight_dtype), name="threshold", dtype=weight_dtype)
84-
self.initialize_training_step_fn_and_all_variables()
84+
self.initialize()
8585

8686
p = pruning_impl.Pruning(
8787
pruning_vars=[(weight, mask, threshold)],
@@ -102,7 +102,7 @@ def testUpdateSingleMask(self):
102102
self.assertAllEqual(np.count_nonzero(mask_after_pruning), 50)
103103

104104
def testConstructsMaskAndThresholdCorrectly(self):
105-
self.initialize_training_step_fn_and_all_variables()
105+
self.initialize()
106106
p = pruning_impl.Pruning(
107107
lambda: 0, None,
108108
# Sparsity math often returns values with small tolerances.
@@ -125,7 +125,7 @@ def _blockMasking(self, block_size, block_pooling_type, weight,
125125
dtype=weight.dtype)
126126
threshold = tf.Variable(
127127
tf.zeros([], dtype=weight.dtype), name="threshold", dtype=weight.dtype)
128-
self.initialize_training_step_fn_and_all_variables()
128+
self.initialize()
129129

130130
# Set up pruning
131131
p = pruning_impl.Pruning(
@@ -163,7 +163,7 @@ def testBlockMaskingMax(self):
163163
self._blockMasking(block_size, block_pooling_type, weight, expected_mask)
164164

165165
def testBlockMaskingWithHigherDimensionsRaisesError(self):
166-
self.initialize_training_step_fn_and_all_variables()
166+
self.initialize()
167167
block_size = (2, 2)
168168
block_pooling_type = "AVG"
169169
# Weights as in testBlockMasking, but with one extra dimension.
@@ -186,7 +186,7 @@ def testConditionalMaskUpdate(self):
186186
dtype=weight_dtype)
187187
threshold = tf.Variable(
188188
tf.zeros([], dtype=weight_dtype), name="threshold", dtype=weight_dtype)
189-
self.initialize_training_step_fn_and_all_variables()
189+
self.initialize()
190190

191191
def linear_sparsity(step):
192192
sparsity_val = tf.convert_to_tensor(

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,8 @@ def no_op():
236236
# Always execute the op that performs weights = weights * mask
237237
# Relies on UpdatePruningStep callback to ensure the weights
238238
# are sparse after the final backpropagation.
239+
#
240+
# self.add_update does nothing during eager execution.
239241
self.add_update(self.pruning_obj.weight_mask_op())
240242

241243
return self.layer.call(inputs)

0 commit comments

Comments
 (0)