Skip to content

Commit f3cbc55

Browse files
Merge pull request #860 from psunn:psunn_quickfix_mbyn
PiperOrigin-RevId: 401647817
2 parents 6569b35 + 162ddc7 commit f3cbc55

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,11 @@ def conditional_mask_update(self):
240240
"""Returns an op to updates masks as per the pruning schedule."""
241241

242242
def maybe_update_masks():
243-
return self._pruning_schedule(self._step_fn())[0]
243+
if self._sparsity_m_by_n:
244+
# Update structured sparsity masks only at step 1
245+
return tf.math.equal(self._step_fn(), 1)
246+
else:
247+
return self._pruning_schedule(self._step_fn())[0]
244248

245249
def no_update():
246250
return tf.no_op()

0 commit comments

Comments
 (0)