We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d43b950 commit ab96106Copy full SHA for ab96106
tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py
@@ -102,7 +102,7 @@ def _update_mask(self, weights):
102
current_threshold = array_ops.gather(values, k - 1)
103
new_mask = math_ops.cast(
104
math_ops.greater_equal(abs_weights, current_threshold),
105
- dtypes.float32)
+ weights.dtype)
106
return current_threshold, new_mask
107
108
def _maybe_update_block_mask(self, weights):
0 commit comments