@@ -98,15 +98,30 @@ def _update_mask(self, weights):
9898 (1 - sparsity )),
9999 1 ),
100100 tf .int32 )
101- # Sort the entire array
102- values , _ = tf .math .top_k (
101+ values , indices = tf .math .top_k (
103102 tf .reshape (abs_weights , [- 1 ]), k = tf .size (abs_weights ))
104- # Grab the (k-1)th value
105103
106- current_threshold = tf .gather (values , k - 1 )
107- new_mask = tf .dtypes .cast (
108- tf .math .greater_equal (abs_weights , current_threshold ), weights .dtype )
109- return current_threshold , new_mask
104+ # Grab the (k-1)th value as a threshold to build pruning mask.
105+ threshold_value = tf .gather (values , k - 1 )
106+ threshold_pos = tf .gather (indices , k - 1 )
107+
108+ # Build a mask of the weight elements whose magnitude is >= the threshold.
109+ # A one-hot mask is OR-ed in so the threshold element is always included.
110+ # TODO(b/208967539): Update the logic to index oriented logic.
111+ new_mask = tf .math .logical_or (
112+ tf .math .greater_equal (abs_weights , threshold_value ),
113+ tf .reshape (
114+ tf .one_hot (
115+ threshold_pos ,
116+ depth = tf .size (abs_weights ),
117+ on_value = True ,
118+ off_value = False ,
119+ dtype = tf .bool ), abs_weights .shape ))
120+
121+ # The threshold and mask are cast back to the weights' data type in case of
122+ # a type mismatch caused by the Keras mixed precision policy.
123+ return tf .dtypes .cast (threshold_value , weights .dtype ), tf .dtypes .cast (
124+ new_mask , weights .dtype )
110125
111126 def _update_mask_sparsity_m_by_n (self , weights , m_by_n = (2 , 4 )):
112127 """Updates the m by n sparsity mask for a given weight tensor.
@@ -221,14 +236,16 @@ def update_var(variable, reduced_value):
221236 if tf .distribute .get_replica_context ():
222237 values_and_vars = []
223238 for weight , mask , _ in self ._pruning_vars :
224- masked_weight = tf .math .multiply (weight , mask )
239+ masked_weight = tf .dtypes .cast (
240+ tf .math .multiply (weight , mask ), dtype = weight .dtype )
225241 values_and_vars .append ((masked_weight , weight ))
226242 if values_and_vars :
227243 assign_objs .append (tf .distribute .get_replica_context ().merge_call (
228244 update_fn , args = (values_and_vars ,)))
229245 else :
230246 for weight , mask , _ in self ._pruning_vars :
231- masked_weight = tf .math .multiply (weight , mask )
247+ masked_weight = tf .dtypes .cast (
248+ tf .math .multiply (weight , mask ), dtype = weight .dtype )
232249 assign_objs .append (tf_compat .assign (weight , masked_weight ))
233250
234251 return assign_objs
0 commit comments