Skip to content

Commit 1bec520

Browse files
rino20tensorflower-gardener
authored andcommitted
Resolve type mismatching bugs between given model weights and pruning variables of pruning api
PiperOrigin-RevId: 414446051
1 parent f19b4e6 commit 1bec520

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,30 @@ def _update_mask(self, weights):
9898
(1 - sparsity)),
9999
1),
100100
tf.int32)
101-
# Sort the entire array
102-
values, _ = tf.math.top_k(
101+
values, indices = tf.math.top_k(
103102
tf.reshape(abs_weights, [-1]), k=tf.size(abs_weights))
104-
# Grab the (k-1)th value
105103

106-
current_threshold = tf.gather(values, k - 1)
107-
new_mask = tf.dtypes.cast(
108-
tf.math.greater_equal(abs_weights, current_threshold), weights.dtype)
109-
return current_threshold, new_mask
104+
# Grab the (k-1)th value as a threshold to build pruning mask.
105+
threshold_value = tf.gather(values, k - 1)
106+
threshold_pos = tf.gather(indices, k - 1)
107+
108+
# Build mask for the weight element higher magnitude than threshold value.
109+
# A mask is added to make sure the threshold element be incorporated.
110+
# TODO(b/208967539): Update the logic to index oriented logic.
111+
new_mask = tf.math.logical_or(
112+
tf.math.greater_equal(abs_weights, threshold_value),
113+
tf.reshape(
114+
tf.one_hot(
115+
threshold_pos,
116+
depth=tf.size(abs_weights),
117+
on_value=True,
118+
off_value=False,
119+
dtype=tf.bool), abs_weights.shape))
120+
121+
# Updated mask is casted back to weight's data type in case of the type
122+
# mismatching due to keras mixed precision policy.
123+
return tf.dtypes.cast(threshold_value, weights.dtype), tf.dtypes.cast(
124+
new_mask, weights.dtype)
110125

111126
def _update_mask_sparsity_m_by_n(self, weights, m_by_n=(2, 4)):
112127
"""Updates the m by n sparsity mask for a given weight tensor.
@@ -221,14 +236,16 @@ def update_var(variable, reduced_value):
221236
if tf.distribute.get_replica_context():
222237
values_and_vars = []
223238
for weight, mask, _ in self._pruning_vars:
224-
masked_weight = tf.math.multiply(weight, mask)
239+
masked_weight = tf.dtypes.cast(
240+
tf.math.multiply(weight, mask), dtype=weight.dtype)
225241
values_and_vars.append((masked_weight, weight))
226242
if values_and_vars:
227243
assign_objs.append(tf.distribute.get_replica_context().merge_call(
228244
update_fn, args=(values_and_vars,)))
229245
else:
230246
for weight, mask, _ in self._pruning_vars:
231-
masked_weight = tf.math.multiply(weight, mask)
247+
masked_weight = tf.dtypes.cast(
248+
tf.math.multiply(weight, mask), dtype=weight.dtype)
232249
assign_objs.append(tf_compat.assign(weight, masked_weight))
233250

234251
return assign_objs

0 commit comments

Comments
 (0)