20
20
21
21
import tensorflow as tf
22
22
23
- # b/(139939526): update assign ops to v2 API.
24
- from tensorflow .python .ops import state_ops
25
23
from tensorflow .python .ops import summary_ops_v2
26
24
from tensorflow .python .summary import summary as summary_ops_v1
27
25
from tensorflow_model_optimization .python .core .sparsity .keras import pruning_utils
28
26
27
+
29
28
class Pruning (object ):
30
29
"""Implementation of magnitude-based weight pruning."""
31
30
@@ -54,6 +53,13 @@ def __init__(self, training_step_fn, pruning_vars, pruning_schedule,
54
53
55
54
self ._validate_block ()
56
55
56
+ @staticmethod
57
+ def _assign (ref , value ):
58
+ if tf .__version__ [0 ] == '1' :
59
+ return tf .assign (ref , value )
60
+ else :
61
+ return ref .assign (value )
62
+
57
63
def _validate_block (self ):
58
64
if self ._block_size != [1 , 1 ]:
59
65
for weight , _ , _ in self ._pruning_vars :
@@ -144,8 +150,15 @@ def _maybe_update_block_mask(self, weights):
144
150
squeezed_weights .get_shape ()[1 ]])
145
151
return new_threshold , tf .reshape (sliced_mask , tf .shape (weights ))
146
152
147
- def _get_weight_assign_ops (self ):
148
- """Gather the assign ops for assigning weights<=weights*mask."""
153
+ def _weight_assign_objs (self ):
154
+ """Gather the assign objs for assigning weights<=weights*mask.
155
+
156
+ The objs are ops for graph execution and tensors for eager
157
+ execution.
158
+
159
+ Returns:
160
+ group of objs for weight assignment.
161
+ """
149
162
150
163
def update_fn (distribution , values_and_vars ):
151
164
# TODO(yunluli): Need this ReduceOp because the weight is created by the
@@ -158,34 +171,34 @@ def update_fn(distribution, values_and_vars):
158
171
values_and_vars = zip (reduced_values , var_list )
159
172
160
173
def update_var (variable , reduced_value ):
161
- return state_ops . assign (variable , reduced_value )
174
+ return self . _assign (variable , reduced_value )
162
175
163
- update_ops = []
176
+ update_objs = []
164
177
for value , var in values_and_vars :
165
- update_ops .append (
178
+ update_objs .append (
166
179
distribution .extended .update (var , update_var , args = (value ,)))
167
180
168
- return tf .group (update_ops )
181
+ return tf .group (update_objs )
169
182
170
- assign_ops = []
183
+ assign_objs = []
171
184
172
185
if tf .distribute .get_replica_context ():
173
186
values_and_vars = []
174
187
for weight , mask , _ in self ._pruning_vars :
175
188
masked_weight = tf .math .multiply (weight , mask )
176
189
values_and_vars .append ((masked_weight , weight ))
177
190
if values_and_vars :
178
- assign_ops .append (tf .distribute .get_replica_context ().merge_call (
191
+ assign_objs .append (tf .distribute .get_replica_context ().merge_call (
179
192
update_fn , args = (values_and_vars ,)))
180
193
else :
181
194
for weight , mask , _ in self ._pruning_vars :
182
195
masked_weight = tf .math .multiply (weight , mask )
183
- assign_ops .append (state_ops . assign (weight , masked_weight ))
196
+ assign_objs .append (self . _assign (weight , masked_weight ))
184
197
185
- return assign_ops
198
+ return assign_objs
186
199
187
200
def weight_mask_op (self ):
188
- return tf .group (self ._get_weight_assign_ops ())
201
+ return tf .group (self ._weight_assign_objs ())
189
202
190
203
def conditional_mask_update (self ):
191
204
"""Returns an op to updates masks as per the pruning schedule."""
@@ -200,35 +213,39 @@ def mask_update():
200
213
"""Updates mask without distribution strategy."""
201
214
202
215
def update ():
203
- assign_ops = []
216
+ assign_objs = []
204
217
205
218
for weight , mask , threshold in self ._pruning_vars :
206
219
new_threshold , new_mask = self ._maybe_update_block_mask (weight )
207
- assign_ops .append (state_ops . assign (threshold , new_threshold ))
208
- assign_ops .append (state_ops . assign (mask , new_mask ))
220
+ assign_objs .append (self . _assign (threshold , new_threshold ))
221
+ assign_objs .append (self . _assign (mask , new_mask ))
209
222
210
- return tf .group (assign_ops )
223
+ return tf .group (assign_objs )
211
224
212
225
return tf .cond (maybe_update_masks (), update , no_update )
213
226
214
227
def mask_update_distributed (distribution ):
215
228
"""Updates mask with distribution strategy."""
216
229
217
230
def update (var , value ):
218
- return state_ops . assign (var , value )
231
+ return self . _assign (var , value )
219
232
220
233
def update_distributed ():
221
- """Gather distributed update ops."""
222
- assign_ops = []
234
+ """Gather distributed update objs.
235
+
236
+ The objs are ops for graph execution and tensors for eager
237
+ execution.
238
+ """
239
+ assign_objs = []
223
240
224
241
for weight , mask , threshold in self ._pruning_vars :
225
242
new_threshold , new_mask = self ._maybe_update_block_mask (weight )
226
- assign_ops .append (
243
+ assign_objs .append (
227
244
distribution .extended .update (mask , update , (new_mask ,)))
228
- assign_ops .append (
245
+ assign_objs .append (
229
246
distribution .extended .update (threshold , update , (new_threshold ,)))
230
247
231
- return tf .group (assign_ops )
248
+ return tf .group (assign_objs )
232
249
233
250
return tf .cond (maybe_update_masks (), update_distributed , no_update )
234
251
0 commit comments