Skip to content

Commit ab96106

Browse files
alanchiaotensorflower-gardener
authored andcommitted
Update mask using type that was used during variable initialization in wrapper.
PiperOrigin-RevId: 285501629
1 parent d43b950 commit ab96106

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _update_mask(self, weights):
102102
current_threshold = array_ops.gather(values, k - 1)
103103
new_mask = math_ops.cast(
104104
math_ops.greater_equal(abs_weights, current_threshold),
105-
dtypes.float32)
105+
weights.dtype)
106106
return current_threshold, new_mask
107107

108108
def _maybe_update_block_mask(self, weights):

0 commit comments

Comments
 (0)