from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.summary import summary as summary_ops_v1
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils

-
class Pruning(object):
  """Implementation of magnitude-based weight pruning."""

@@ -55,15 +55,9 @@ def __init__(self, training_step_fn, pruning_vars, pruning_schedule,
     self._block_pooling_type = block_pooling_type
     self._validate_block()

-    # List of tensorflow assignments ops for new masks and thresholds
-    self._assign_ops = []
-
     # Training step
     self._step_fn = training_step_fn

-    # List of tensorflow assignment ops for the weights
-    self._weight_assign_ops = []
-
     self._validate_block()

  def _validate_block(self):
@@ -73,9 +67,6 @@ def _validate_block(self):
      raise ValueError('Block Sparsity can only be used for layers which '
                       'have 2-dimensional weights.')

-  def get_weight_sparsity(self):
-    return [math_ops.reduce_mean(weight) for weight, _, _ in self._pruning_vars]
-
  def _update_mask(self, weights):
    """Updates the mask for a given weight tensor.

@@ -161,69 +152,99 @@ def _maybe_update_block_mask(self, weights):
    return new_threshold, array_ops.reshape(sliced_mask,
                                             array_ops.shape(weights))

-  def _get_assign_ops(self):
-    """Gather the assign ops for assigning updated masks and threshold."""
-    # Make sure the assignment ops have not already been added to the list
-    if self._assign_ops:
-      raise ValueError(
-          'Assign op list not empty. _get_assign_ops() called twice?')
-
-    for weight, mask, threshold in self._pruning_vars:
-      is_partitioned = isinstance(weight, variables.PartitionedVariable)
-      weight_as_tensor = weight
-      if is_partitioned:
-        weight_as_tensor = weight.as_tensor()
-
-      new_threshold, new_mask = self._maybe_update_block_mask(weight_as_tensor)
-      self._assign_ops.append(
-          pruning_utils.variable_assign(threshold, new_threshold))
-
-      self._assign_ops.append(
-          pruning_utils.partitioned_variable_assign(mask, new_mask)
-          if is_partitioned else pruning_utils.variable_assign(mask, new_mask))
-
  def _get_weight_assign_ops(self):
    """Gather the assign ops for assigning weights<=weights*mask."""
-    if self._weight_assign_ops:
-      raise ValueError(
-          'Assign op list not empty. _get_weight_assign_ops() called twice?')
-
-    for weight, mask, _ in self._pruning_vars:
-      is_partitioned = isinstance(weight, variables.PartitionedVariable)
-      masked_weight = math_ops.multiply(weight, mask)
-      self._weight_assign_ops.append(
-          pruning_utils.partitioned_variable_assign(weight, masked_weight)
-          if is_partitioned else pruning_utils
-          .variable_assign(weight, masked_weight))
-
-  def weight_mask_op(self):
-    if tf.executing_eagerly() or not self._weight_assign_ops:
-      self._weight_assign_ops = []
-      self._get_weight_assign_ops()
-
-    with ops.control_dependencies(self._weight_assign_ops):
-      return control_flow_ops.no_op('mask_weights')

-  def mask_update_op(self):
-    self._assign_ops = []
-    self._get_assign_ops()
+    def update_fn(distribution, values_and_vars):
+      # TODO(yunluli): Need this ReduceOp because the weight is created by the
+      # layer wrapped, so we don't have control of its aggregation policy. May
+      # be able to optimize this when distribution strategy supports easier
+      # update to mirrored variables in replica context.
+      reduced_values = distribution.extended.batch_reduce_to(
+          tf.distribute.ReduceOp.MEAN, values_and_vars)
+      var_list = [v for _, v in values_and_vars]
+      values_and_vars = zip(reduced_values, var_list)
+
+      def update_var(variable, reduced_value):
+        return state_ops.assign(variable, reduced_value)
+
+      update_ops = []
+      for value, var in values_and_vars:
+        update_ops.append(
+            distribution.extended.update(var, update_var, args=(value,)))
+
+      return control_flow_ops.group(update_ops)
+
+    assign_ops = []
+
+    if tf.distribute.get_replica_context():
+      values_and_vars = []
+      for weight, mask, _ in self._pruning_vars:
+        masked_weight = math_ops.multiply(weight, mask)
+        values_and_vars.append((masked_weight, weight))
+      assign_ops.append(tf.distribute.get_replica_context().merge_call(
+          update_fn, args=(values_and_vars,)))
+    else:
+      for weight, mask, _ in self._pruning_vars:
+        masked_weight = math_ops.multiply(weight, mask)
+        assign_ops.append(state_ops.assign(weight, masked_weight))
+
+    return assign_ops

-    with ops.control_dependencies(self._assign_ops):
-      return control_flow_ops.no_op('mask_update')
+  def weight_mask_op(self):
+    return control_flow_ops.group(self._get_weight_assign_ops())

  def conditional_mask_update(self):
    """Returns an op to updates masks as per the pruning schedule."""

    def maybe_update_masks():
      return self._pruning_schedule(self._step_fn())[0]

-    def mask_update_op():
-      return self.mask_update_op()
-
-    def no_op():
+    def no_update():
      return control_flow_ops.no_op()

-    return control_flow_ops.cond(maybe_update_masks(), mask_update_op, no_op)
+    def mask_update():
+      """Updates mask without distribution strategy."""
+
+      def update():
+        assign_ops = []
+
+        for weight, mask, threshold in self._pruning_vars:
+          new_threshold, new_mask = self._maybe_update_block_mask(weight)
+          assign_ops.append(state_ops.assign(threshold, new_threshold))
+          assign_ops.append(state_ops.assign(mask, new_mask))
+
+        return control_flow_ops.group(assign_ops)
+
+      return control_flow_ops.cond(maybe_update_masks(), update, no_update)
+
+    def mask_update_distributed(distribution):
+      """Updates mask with distribution strategy."""
+
+      def update(var, value):
+        return state_ops.assign(var, value)
+
+      def update_distributed():
+        """Gather distributed update ops."""
+        assign_ops = []
+
+        for weight, mask, threshold in self._pruning_vars:
+          new_threshold, new_mask = self._maybe_update_block_mask(weight)
+          assign_ops.append(
+              distribution.extended.update(mask, update, (new_mask,)))
+          assign_ops.append(
+              distribution.extended.update(threshold, update, (new_threshold,)))
+
+        return control_flow_ops.group(assign_ops)
+
+      return control_flow_ops.cond(maybe_update_masks(), update_distributed,
+                                   no_update)
+
+    if tf.distribute.get_replica_context():
+      return tf.distribute.get_replica_context().merge_call(
+          mask_update_distributed)
+    else:
+      return mask_update()

  def add_pruning_summaries(self):
    """Adds summaries of weight sparsities and thresholds."""