We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 6569b35 + 162ddc7 commit f3cbc55Copy full SHA for f3cbc55
tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py
@@ -240,7 +240,11 @@ def conditional_mask_update(self):
240
"""Returns an op to updates masks as per the pruning schedule."""
241
242
def maybe_update_masks():
243
- return self._pruning_schedule(self._step_fn())[0]
+ if self._sparsity_m_by_n:
244
+ # Update structured sparsity masks only at step 1
245
+ return tf.math.equal(self._step_fn(), 1)
246
+ else:
247
+ return self._pruning_schedule(self._step_fn())[0]
248
249
def no_update():
250
return tf.no_op()
0 commit comments