Skip to content

Commit a0bf7a4

Browse files
liyunlu0618alanchiao
authored andcommitted
Fix pruning with distribution strategy.
PiperOrigin-RevId: 246366931
1 parent 3ebe133 commit a0bf7a4

File tree

6 files changed

+95
-185
lines changed

6 files changed

+95
-185
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,10 @@ py_library(
148148
# python:array_ops tensorflow dep2,
149149
# python:control_flow_ops tensorflow dep2,
150150
# python:dtypes tensorflow dep2,
151-
# python:framework tensorflow dep2,
152151
# python:framework_ops tensorflow dep2,
153152
# python:math_ops tensorflow dep2,
154153
# python:nn_ops tensorflow dep2,
155-
# python:platform tensorflow dep2,
154+
# python:state_ops tensorflow dep2,
156155
# python:summary tensorflow dep2,
157156
# python:variables tensorflow dep2,
158157
],

tensorflow_model_optimization/python/core/sparsity/keras/prune_distributed_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# ==============================================================================
1515
"""Distributed pruning test."""
1616

17+
import tempfile
1718
from absl.testing import parameterized
1819
import numpy as np
1920

@@ -71,6 +72,14 @@ def testPrunesSimpleDenseModel(self, distribution):
7172
model.predict(np.random.rand(20, 10))
7273
test_utils.assert_model_sparsity(self, 0.5, model)
7374

75+
_, keras_file = tempfile.mkstemp('.h5')
76+
keras.models.save_model(model, keras_file)
77+
78+
with prune.prune_scope():
79+
loaded_model = keras.models.load_model(keras_file)
80+
81+
test_utils.assert_model_sparsity(self, 0.5, loaded_model)
82+
7483

7584
if __name__ == '__main__':
7685
test.main()

tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py

Lines changed: 81 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@
2626
from tensorflow.python.ops import control_flow_ops
2727
from tensorflow.python.ops import math_ops
2828
from tensorflow.python.ops import nn_ops
29+
from tensorflow.python.ops import state_ops
2930
from tensorflow.python.ops import summary_ops_v2
3031
from tensorflow.python.ops import variables
3132
from tensorflow.python.summary import summary as summary_ops_v1
3233
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
3334

34-
3535
class Pruning(object):
3636
"""Implementation of magnitude-based weight pruning."""
3737

@@ -55,15 +55,9 @@ def __init__(self, training_step_fn, pruning_vars, pruning_schedule,
5555
self._block_pooling_type = block_pooling_type
5656
self._validate_block()
5757

58-
# List of tensorflow assignments ops for new masks and thresholds
59-
self._assign_ops = []
60-
6158
# Training step
6259
self._step_fn = training_step_fn
6360

64-
# List of tensorflow assignment ops for the weights
65-
self._weight_assign_ops = []
66-
6761
self._validate_block()
6862

6963
def _validate_block(self):
@@ -73,9 +67,6 @@ def _validate_block(self):
7367
raise ValueError('Block Sparsity can only be used for layers which '
7468
'have 2-dimensional weights.')
7569

76-
def get_weight_sparsity(self):
77-
return [math_ops.reduce_mean(weight) for weight, _, _ in self._pruning_vars]
78-
7970
def _update_mask(self, weights):
8071
"""Updates the mask for a given weight tensor.
8172
@@ -161,69 +152,99 @@ def _maybe_update_block_mask(self, weights):
161152
return new_threshold, array_ops.reshape(sliced_mask,
162153
array_ops.shape(weights))
163154

164-
def _get_assign_ops(self):
165-
"""Gather the assign ops for assigning updated masks and threshold."""
166-
# Make sure the assignment ops have not already been added to the list
167-
if self._assign_ops:
168-
raise ValueError(
169-
'Assign op list not empty. _get_assign_ops() called twice?')
170-
171-
for weight, mask, threshold in self._pruning_vars:
172-
is_partitioned = isinstance(weight, variables.PartitionedVariable)
173-
weight_as_tensor = weight
174-
if is_partitioned:
175-
weight_as_tensor = weight.as_tensor()
176-
177-
new_threshold, new_mask = self._maybe_update_block_mask(weight_as_tensor)
178-
self._assign_ops.append(
179-
pruning_utils.variable_assign(threshold, new_threshold))
180-
181-
self._assign_ops.append(
182-
pruning_utils.partitioned_variable_assign(mask, new_mask)
183-
if is_partitioned else pruning_utils.variable_assign(mask, new_mask))
184-
185155
def _get_weight_assign_ops(self):
186156
"""Gather the assign ops for assigning weights<=weights*mask."""
187-
if self._weight_assign_ops:
188-
raise ValueError(
189-
'Assign op list not empty. _get_weight_assign_ops() called twice?')
190-
191-
for weight, mask, _ in self._pruning_vars:
192-
is_partitioned = isinstance(weight, variables.PartitionedVariable)
193-
masked_weight = math_ops.multiply(weight, mask)
194-
self._weight_assign_ops.append(
195-
pruning_utils.partitioned_variable_assign(weight, masked_weight)
196-
if is_partitioned else pruning_utils
197-
.variable_assign(weight, masked_weight))
198-
199-
def weight_mask_op(self):
200-
if tf.executing_eagerly() or not self._weight_assign_ops:
201-
self._weight_assign_ops = []
202-
self._get_weight_assign_ops()
203-
204-
with ops.control_dependencies(self._weight_assign_ops):
205-
return control_flow_ops.no_op('mask_weights')
206157

207-
def mask_update_op(self):
208-
self._assign_ops = []
209-
self._get_assign_ops()
158+
def update_fn(distribution, values_and_vars):
159+
# TODO(yunluli): Need this ReduceOp because the weight is created by the
160+
# layer wrapped, so we don't have control of its aggregation policy. May
161+
# be able to optimize this when distribution strategy supports easier
162+
# update to mirrored variables in replica context.
163+
reduced_values = distribution.extended.batch_reduce_to(
164+
tf.distribute.ReduceOp.MEAN, values_and_vars)
165+
var_list = [v for _, v in values_and_vars]
166+
values_and_vars = zip(reduced_values, var_list)
167+
168+
def update_var(variable, reduced_value):
169+
return state_ops.assign(variable, reduced_value)
170+
171+
update_ops = []
172+
for value, var in values_and_vars:
173+
update_ops.append(
174+
distribution.extended.update(var, update_var, args=(value,)))
175+
176+
return control_flow_ops.group(update_ops)
177+
178+
assign_ops = []
179+
180+
if tf.distribute.get_replica_context():
181+
values_and_vars = []
182+
for weight, mask, _ in self._pruning_vars:
183+
masked_weight = math_ops.multiply(weight, mask)
184+
values_and_vars.append((masked_weight, weight))
185+
assign_ops.append(tf.distribute.get_replica_context().merge_call(
186+
update_fn, args=(values_and_vars,)))
187+
else:
188+
for weight, mask, _ in self._pruning_vars:
189+
masked_weight = math_ops.multiply(weight, mask)
190+
assign_ops.append(state_ops.assign(weight, masked_weight))
191+
192+
return assign_ops
210193

211-
with ops.control_dependencies(self._assign_ops):
212-
return control_flow_ops.no_op('mask_update')
194+
def weight_mask_op(self):
195+
return control_flow_ops.group(self._get_weight_assign_ops())
213196

214197
def conditional_mask_update(self):
215198
"""Returns an op to updates masks as per the pruning schedule."""
216199

217200
def maybe_update_masks():
218201
return self._pruning_schedule(self._step_fn())[0]
219202

220-
def mask_update_op():
221-
return self.mask_update_op()
222-
223-
def no_op():
203+
def no_update():
224204
return control_flow_ops.no_op()
225205

226-
return control_flow_ops.cond(maybe_update_masks(), mask_update_op, no_op)
206+
def mask_update():
207+
"""Updates mask without distribution strategy."""
208+
209+
def update():
210+
assign_ops = []
211+
212+
for weight, mask, threshold in self._pruning_vars:
213+
new_threshold, new_mask = self._maybe_update_block_mask(weight)
214+
assign_ops.append(state_ops.assign(threshold, new_threshold))
215+
assign_ops.append(state_ops.assign(mask, new_mask))
216+
217+
return control_flow_ops.group(assign_ops)
218+
219+
return control_flow_ops.cond(maybe_update_masks(), update, no_update)
220+
221+
def mask_update_distributed(distribution):
222+
"""Updates mask with distribution strategy."""
223+
224+
def update(var, value):
225+
return state_ops.assign(var, value)
226+
227+
def update_distributed():
228+
"""Gather distributed update ops."""
229+
assign_ops = []
230+
231+
for weight, mask, threshold in self._pruning_vars:
232+
new_threshold, new_mask = self._maybe_update_block_mask(weight)
233+
assign_ops.append(
234+
distribution.extended.update(mask, update, (new_mask,)))
235+
assign_ops.append(
236+
distribution.extended.update(threshold, update, (new_threshold,)))
237+
238+
return control_flow_ops.group(assign_ops)
239+
240+
return control_flow_ops.cond(maybe_update_masks(), update_distributed,
241+
no_update)
242+
243+
if tf.distribute.get_replica_context():
244+
return tf.distribute.get_replica_context().merge_call(
245+
mask_update_distributed)
246+
else:
247+
return mask_update()
227248

228249
def add_pruning_summaries(self):
229250
"""Adds summaries of weight sparsities and thresholds."""

tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ def testUpdateSingleMask(self):
6767
self.assertAllEqual(np.count_nonzero(mask_before_pruning), 100)
6868

6969
if context.executing_eagerly():
70-
p.mask_update_op()
70+
p.conditional_mask_update()
7171
else:
72-
K.get_session().run(p.mask_update_op())
72+
K.get_session().run(p.conditional_mask_update())
7373

7474
mask_after_pruning = K.get_value(mask)
7575
self.assertAllEqual(np.count_nonzero(mask_after_pruning), 50)
@@ -143,31 +143,6 @@ def testBlockMaskingWithHigherDimensionsRaisesError(self):
143143
with self.assertRaises(ValueError):
144144
self._blockMasking(block_size, block_pooling_type, weight, expected_mask)
145145

146-
def testPartitionedVariableMasking(self):
147-
partitioner = partitioned_variables.variable_axis_size_partitioner(40)
148-
with self.cached_session():
149-
with variable_scope.variable_scope("", partitioner=partitioner):
150-
weight = variable_scope.get_variable(
151-
"weights", initializer=math_ops.linspace(1.0, 100.0, 100))
152-
mask = pruning_utils.mask_variable(weight)
153-
threshold = pruning_utils.threshold_variable(weight)
154-
155-
p = pruning_impl.Pruning(
156-
pruning_vars=[(weight, mask, threshold)],
157-
training_step_fn=self.training_step_fn,
158-
pruning_schedule=self.constant_sparsity,
159-
block_size=self.block_size,
160-
block_pooling_type=self.block_pooling_type)
161-
162-
if context.executing_eagerly():
163-
p.mask_update_op()
164-
else:
165-
variables.global_variables_initializer().run()
166-
K.get_session().run(p.mask_update_op())
167-
168-
mask_after_pruning = K.get_value(mask.as_tensor())
169-
self.assertAllEqual(np.count_nonzero(mask_after_pruning), 50)
170-
171146
def testConditionalMaskUpdate(self):
172147
weight = K.variable(np.linspace(1.0, 100.0, 100), name="weights")
173148
mask = K.ones(weight.get_shape())

tensorflow_model_optimization/python/core/sparsity/keras/pruning_utils.py

Lines changed: 0 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -34,52 +34,6 @@
3434
from tensorflow.python.ops import variable_scope
3535

3636

37-
def mask_variable(var, scope=''):
38-
"""Create a mask for the weights.
39-
40-
This function adds a variable 'mask' to the graph.
41-
42-
Args:
43-
var: the weight variable that needs to be masked
44-
scope: The variable scope of the variable var
45-
46-
Returns:
47-
the mask variable of the same size and shape as var, initialized to all 1s.
48-
"""
49-
with variable_scope.variable_scope(scope):
50-
# TODO(suyoggupta): Remove variable_scope dependency
51-
mask = variable_scope.get_variable(
52-
'mask',
53-
var.get_shape(),
54-
initializer=init_ops.ones_initializer(),
55-
trainable=False,
56-
dtype=var.dtype)
57-
return mask
58-
59-
60-
def threshold_variable(var, scope=''):
61-
"""Create a scalar threshold for the weights.
62-
63-
This function adds a variable
64-
'threshold' to the graph.
65-
66-
Args:
67-
var: The weight variable that needs to be masked
68-
scope: The variable scope of the variable var
69-
70-
Returns:
71-
A scalar threshold variable initialized to 0.
72-
"""
73-
with variable_scope.variable_scope(scope):
74-
# TODO(suyoggupta): Remove variable_scope dependency
75-
threshold = variable_scope.get_variable(
76-
'threshold', [],
77-
initializer=init_ops.zeros_initializer(),
78-
trainable=False,
79-
dtype=var.dtype)
80-
return threshold
81-
82-
8337
def kronecker_product(mat1, mat2):
8438
"""Computes the Kronecker product of two matrices mat1 and mat2.
8539
@@ -97,7 +51,6 @@ def kronecker_product(mat1, mat2):
9751
mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
9852
return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
9953

100-
10154
def expand_tensor(tensor, block_size):
10255
"""Expands a 2D tensor by replicating the tensor values.
10356
@@ -213,50 +166,3 @@ def factorized_pool(input_tensor,
213166

214167
return array_ops.squeeze(
215168
array_ops.transpose(width_pooling, perm=[0, 1, 3, 2]))
216-
217-
218-
def determine_partitioned_axis(partitioned_variable):
219-
partitioned_axis = 0
220-
concatenated_variable_shape = partitioned_variable.get_shape()
221-
for partition in partitioned_variable:
222-
partition_shape = partition.get_shape()
223-
maybe_partitioned_axis = np.less(partition_shape,
224-
concatenated_variable_shape)
225-
# Sanity check: make sure number of partitioned axis == 1
226-
if np.count_nonzero(maybe_partitioned_axis) != 1:
227-
raise ValueError('Number of partitioned axes %s not equal to 1' %
228-
np.count_nonzero(maybe_partitioned_axis))
229-
partitioned_axis = np.where(maybe_partitioned_axis)[0][0]
230-
return partitioned_axis
231-
232-
233-
def variable_assign(var, new_value):
234-
return state_ops.assign(var, new_value)
235-
236-
237-
def partitioned_variable_assign(partitioned_var, new_value):
238-
"""Assign op for partitioned variables.
239-
240-
Args:
241-
partitioned_var: A partitioned tensorflow variable
242-
new_value: Value to be assigned to the variable var
243-
244-
Returns:
245-
A tensorflow op that groups the assign ops for each of the variable slices
246-
"""
247-
# Determine which axis was used to partition the variable. Currently
248-
# tensorflow allows partitioning variable only along 1 axis.
249-
axis = 0 if len(partitioned_var) == 1 else determine_partitioned_axis(
250-
partitioned_var)
251-
252-
partition_sizes = np.array(
253-
[partition.get_shape()[axis] for partition in partitioned_var])
254-
new_partitioned_values = array_ops.split(
255-
new_value,
256-
ops.convert_to_tensor(partition_sizes, dtype=dtypes.int32),
257-
axis=axis)
258-
op_list = []
259-
for partition in partitioned_var:
260-
op_list.append(
261-
variable_assign(partition, new_partitioned_values[len(op_list)]))
262-
return control_flow_ops.group(*op_list)

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,9 @@ def no_op():
231231
return control_flow_ops.no_op('no_update')
232232

233233
update_op = tf_utils.smart_cond(training, add_update, no_op)
234-
self.layer.add_update(update_op)
234+
self.add_update(update_op)
235235
# Always execute the op that performs weights = weights * mask
236-
self.layer.add_update(self.pruning_obj.weight_mask_op())
236+
self.add_update(self.pruning_obj.weight_mask_op())
237237

238238
return self.layer.call(inputs)
239239

0 commit comments

Comments
 (0)