@@ -58,6 +58,7 @@ def __init__(self,
5858 number_of_clusters ,
5959 cluster_centroids_init = CentroidInitialization .KMEANS_PLUS_PLUS ,
6060 preserve_sparsity = False ,
61+ cluster_per_channel = False ,
6162 cluster_gradient_aggregation = GradientAggregation .SUM ,
6263 ** kwargs ):
6364 if not isinstance (layer , Layer ):
@@ -101,6 +102,17 @@ def __init__(self,
101102 # The number of cluster centroids
102103 self .number_of_clusters = number_of_clusters
103104
105+ # Whether to cluster Conv2D kernels per-channel.
106+ # In case the layer isn't a Conv2D, this isn't
107+ # applicable
108+ self .cluster_per_channel = (
109+ cluster_per_channel if isinstance (layer , tf .keras .layers .Conv2D )
110+ else False )
111+
112+ # Number of channels in a Conv2D layer, to be
113+ # used the case of per-channel clustering
114+ self .num_channels = None
115+
104116 # Whether to apply sparsity preservation or not
105117 self .preserve_sparsity = preserve_sparsity
106118
@@ -137,12 +149,31 @@ def __init__(self,
137149 hasattr (layer , '_batch_input_shape' )):
138150 self ._batch_input_shape = self .layer ._batch_input_shape
139151
152+ # In the case of Conv2D layer, the data_format
153+ # needs to be preserved to be used for per-channel
154+ # clustering
155+ if hasattr (layer , 'data_format' ):
156+ self .data_format = self .layer .data_format
157+ else :
158+ self .data_format = None
159+
140160 # Save the input shape specified in the build
141161 self .build_input_shape = None
142162
143163 def _make_layer_name (self , layer ):
144164 return '{}_{}' .format ('cluster' , layer .name )
145165
166+ def _get_zero_idx_mask (self , centroids , zero_cluster ):
167+ zero_idx_mask = (tf .cast (tf .math .not_equal (centroids ,
168+ zero_cluster ),
169+ dtype = tf .float32 ))
170+ return zero_idx_mask
171+
172+ def _get_zero_centroid (self , centroids , zero_idx_mask ):
173+ zero_centroid = tf .math .multiply (centroids ,
174+ zero_idx_mask )
175+ return zero_centroid
176+
146177 def get_weight_from_layer (self , weight_name ):
147178 return getattr (self .layer , weight_name )
148179
@@ -173,15 +204,28 @@ def build(self, input_shape):
173204 i for i , w in enumerate (self .layer .weights ) if w is original_weight )
174205 self .position_original_weights [position_original_weight ] = weight_name
175206
207+ # In the case of per-channel clustering, the number of channels,
208+ # per-channel number of clusters, as well as the overall number
209+ # of clusters all need to be preserved in the wrapper.
210+ if self .cluster_per_channel :
211+ self .num_channels = (
212+ original_weight .shape [1 ] if self .data_format == "channels_first"
213+ else original_weight .shape [- 1 ])
214+
215+ centroid_init_factory = clustering_centroids .CentroidsInitializerFactory
216+ centroid_init = centroid_init_factory .get_centroid_initializer (
217+ self .cluster_centroids_init )(
218+ weight , self .number_of_clusters ,
219+ self .cluster_per_channel ,
220+ self .num_channels ,
221+ self .preserve_sparsity )
222+
176223 # Init the cluster centroids
177- cluster_centroids = (
178- clustering_centroids .CentroidsInitializerFactory
179- .get_centroid_initializer (self .cluster_centroids_init )(
180- weight , self .number_of_clusters ,
181- self .preserve_sparsity ).get_cluster_centroids ())
224+ cluster_centroids = (centroid_init .get_cluster_centroids ())
225+
182226 self .cluster_centroids [weight_name ] = self .add_weight (
183227 '{}{}' .format ('cluster_centroids_' , weight_name ),
184- shape = (self . number_of_clusters , ),
228+ shape = (cluster_centroids . shape ),
185229 dtype = weight .dtype ,
186230 trainable = True ,
187231 initializer = tf .keras .initializers .Constant (value = cluster_centroids ))
@@ -198,10 +242,11 @@ def build(self, input_shape):
198242 weight_name_no_index = weight_name
199243 self .clustering_algorithms [weight_name ] = (
200244 clustering_registry .ClusteringLookupRegistry ().get_clustering_impl (
201- self .layer , weight_name_no_index )
245+ self .layer , weight_name_no_index , self . cluster_per_channel )
202246 (
203247 clusters_centroids = self .cluster_centroids [weight_name ],
204248 cluster_gradient_aggregation = self .cluster_gradient_aggregation ,
249+ data_format = self .data_format ,
205250 ))
206251
207252 # Init the pulling_indices (weights associations)
@@ -233,18 +278,27 @@ def update_clustered_weights_associations(self):
233278 ):
234279
235280 if self .preserve_sparsity :
236- # Set the smallest centroid to zero to force sparsity
237- # and avoid extra cluster from forming
238- zero_idx_mask = (
239- tf .cast (
240- tf .math .not_equal (
241- self .cluster_centroids [weight_name ],
242- self .cluster_centroids [weight_name ][
243- self .zero_idx [weight_name ]]),
244- dtype = tf .float32 ))
245- self .cluster_centroids [weight_name ].assign (
246- tf .math .multiply (self .cluster_centroids [weight_name ],
247- zero_idx_mask ))
281+ # In the case of per-channel clustering, sparsity
282+ # needs to be preserved per-channel
283+ if self .cluster_per_channel :
284+ for channel in range (self .num_channels ):
285+ zero_idx_mask = (
286+ self ._get_zero_idx_mask (self .cluster_centroids [weight_name ][channel ],
287+ self .cluster_centroids [weight_name ][channel ][
288+ self .zero_idx [weight_name ][channel ]]))
289+ self .cluster_centroids [weight_name ][channel ].assign (
290+ self ._get_zero_centroid (self .cluster_centroids [weight_name ][channel ],
291+ zero_idx_mask ))
292+ else :
293+ # Set the smallest centroid to zero to force sparsity
294+ # and avoid extra cluster from forming
295+ zero_idx_mask = self ._get_zero_idx_mask (self .cluster_centroids [weight_name ],
296+ self .cluster_centroids [weight_name ][
297+ self .zero_idx [weight_name ]])
298+ self .cluster_centroids [weight_name ].assign (
299+ self ._get_zero_centroid (self .cluster_centroids [weight_name ],
300+ zero_idx_mask ))
301+
248302 # During training, the original zero weights can drift slightly.
249303 # We want to prevent this by forcing them to stay zero at the places
250304 # where they were originally zero to begin with.
@@ -284,6 +338,7 @@ def get_config(self):
284338 'cluster_centroids_init' : self .cluster_centroids_init ,
285339 'preserve_sparsity' : self .preserve_sparsity ,
286340 'cluster_gradient_aggregation' : self .cluster_gradient_aggregation ,
341+ 'cluster_per_channel' : self .cluster_per_channel ,
287342 ** base_config
288343 }
289344 return config
@@ -296,12 +351,14 @@ def from_config(cls, config, custom_objects=None):
296351 cluster_centroids_init = config .pop ('cluster_centroids_init' )
297352 preserve_sparsity = config .pop ('preserve_sparsity' )
298353 cluster_gradient_aggregation = config .pop ('cluster_gradient_aggregation' )
354+ cluster_per_channel = config .pop ('cluster_per_channel' )
299355
300356 config ['number_of_clusters' ] = number_of_clusters
301357 config ['cluster_centroids_init' ] = cluster_config .CentroidInitialization (
302358 cluster_centroids_init )
303359 config ['preserve_sparsity' ] = preserve_sparsity
304360 config ['cluster_gradient_aggregation' ] = cluster_gradient_aggregation
361+ config ['cluster_per_channel' ] = cluster_per_channel
305362
306363 layer = tf .keras .layers .deserialize (
307364 config .pop ('layer' ), custom_objects = custom_objects )
0 commit comments