@@ -152,7 +152,8 @@ def testValuesRemainClusteredAfterTraining(self):
 
   @keras_parameterized.run_all_keras_modes
   def testSparsityIsPreservedDuringTraining(self):
-    """Set a specific random seed to ensure that we get some null weights to test sparsity preservation with."""
+    # Set a specific random seed to ensure that we get some null weights to
+    # test sparsity preservation with.
     tf.random.set_seed(1)
 
     # Verifies that training a clustered model does not destroy the sparsity of
@@ -187,7 +188,8 @@ def testSparsityIsPreservedDuringTraining(self):
     stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
     weights_after_tuning = stripped_model_after_tuning.layers[0].kernel
     non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
-    weights_as_list_after_tuning = weights_after_tuning.numpy().reshape(-1,).tolist()
+    weights_as_list_after_tuning = weights_after_tuning.numpy().reshape(
+        -1,).tolist()
     unique_weights_after_tuning = set(weights_as_list_after_tuning)
 
     # Check that the null weights stayed the same before and after tuning.
@@ -245,15 +247,15 @@ def testEndToEndDeepLayer(self):
 
     def clusters_check(stripped_model):
       # inner dense layer
-      weights_as_list = stripped_model.submodules[1].trainable_weights[0].\
-          numpy().flatten()
+      weights_as_list = (
+          stripped_model.submodules[1].trainable_weights[0].numpy().flatten())
       unique_weights = set(weights_as_list)
       self.assertLessEqual(
           len(unique_weights), self.params["number_of_clusters"])
 
       # outer dense layer
-      weights_as_list = stripped_model.submodules[4].trainable_weights[0].\
-          numpy().flatten()
+      weights_as_list = (
+          stripped_model.submodules[4].trainable_weights[0].numpy().flatten())
       unique_weights = set(weights_as_list)
       self.assertLessEqual(
           len(unique_weights), self.params["number_of_clusters"])
@@ -276,23 +278,22 @@ def testEndToEndDeepLayer2(self):
 
     def clusters_check(stripped_model):
       # first inner dense layer
-      weights_as_list = stripped_model.submodules[1].trainable_weights[0].\
-          numpy().flatten()
+      weights_as_list = (
+          stripped_model.submodules[1].trainable_weights[0].numpy().flatten())
       unique_weights = set(weights_as_list)
       self.assertLessEqual(
           len(unique_weights), self.params["number_of_clusters"])
 
       # second inner dense layer
-      weights_as_list = stripped_model.submodules[4].\
-          trainable_weights[0].\
-          numpy().flatten()
+      weights_as_list = (
+          stripped_model.submodules[4].trainable_weights[0].numpy().flatten())
       unique_weights = set(weights_as_list)
       self.assertLessEqual(
           len(unique_weights), self.params["number_of_clusters"])
 
       # outer dense layer
-      weights_as_list = stripped_model.submodules[7].trainable_weights[0].\
-          numpy().flatten()
+      weights_as_list = (
+          stripped_model.submodules[7].trainable_weights[0].numpy().flatten())
       unique_weights = set(weights_as_list)
       self.assertLessEqual(
           len(unique_weights), self.params["number_of_clusters"])
@@ -301,51 +302,46 @@ def clusters_check(stripped_model):
 
   @keras_parameterized.run_all_keras_modes
   def testWeightsAreLearningDuringClustering(self):
-    """Verifies that training a clustered model does update
-    original_weights, clustered_centroids and bias."""
-    original_model = keras.Sequential([
-        layers.Dense(5, input_shape=(5,))
-    ])
+    """Verifies that weights are updated while training a clustered model.
+
+    Training a clustered model should update original_weights,
+    clustered_centroids and bias.
+    """
+    original_model = keras.Sequential([layers.Dense(5, input_shape=(5,))])
 
     clustered_model = cluster.cluster_weights(original_model, **self.params)
 
     clustered_model.compile(
-      loss=keras.losses.categorical_crossentropy,
-      optimizer="adam",
-      metrics=["accuracy"],
+        loss=keras.losses.categorical_crossentropy,
+        optimizer="adam",
+        metrics=["accuracy"],
     )
 
     class CheckWeightsCallback(keras.callbacks.Callback):
+
       def on_train_batch_begin(self, batch, logs=None):
         # Save weights before batch
         self.original_weight_kernel = (
-            self.model.layers[0].original_clusterable_weights['kernel'].numpy()
-        )
+            self.model.layers[0].original_clusterable_weights["kernel"].numpy())
         self.cluster_centroids_kernel = (
-            self.model.layers[0].cluster_centroids['kernel'].numpy()
-        )
-        self.bias = (
-            self.model.layers[0].layer.bias.numpy()
-        )
+            self.model.layers[0].cluster_centroids["kernel"].numpy())
+        self.bias = (self.model.layers[0].layer.bias.numpy())
 
       def on_train_batch_end(self, batch, logs=None):
         # Check weights are different after batch
         assert not np.array_equal(
-            self.original_weight_kernel,
-            self.model.layers[0].original_clusterable_weights['kernel'].numpy()
-        )
+            self.original_weight_kernel,
+            self.model.layers[0].original_clusterable_weights["kernel"].numpy())
         assert not np.array_equal(
-            self.cluster_centroids_kernel,
-            self.model.layers[0].cluster_centroids['kernel'].numpy()
-        )
-        assert not np.array_equal(
-            self.bias,
-            self.model.layers[0].layer.bias.numpy()
-        )
-
-    clustered_model.fit(x=self.dataset_generator(),
-                        steps_per_epoch=5,
-                        callbacks=[CheckWeightsCallback()])
+            self.cluster_centroids_kernel,
+            self.model.layers[0].cluster_centroids["kernel"].numpy())
+        assert not np.array_equal(self.bias,
+                                  self.model.layers[0].layer.bias.numpy())
+
+    clustered_model.fit(
+        x=self.dataset_generator(),
+        steps_per_epoch=5,
+        callbacks=[CheckWeightsCallback()])
 
 
 if __name__ == "__main__":
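For context, the round trip these tests exercise can be sketched with the public tensorflow_model_optimization clustering API. This is a minimal sketch, not the test code above; the toy model, random data, and the parameter values in clustering_params are illustrative assumptions.

# Minimal sketch of the clustering round trip checked by these tests, assuming
# the public tensorflow_model_optimization (tfmot) clustering API; the toy
# data and parameter values are illustrative only.
import numpy as np
from tensorflow import keras
import tensorflow_model_optimization as tfmot

clustering_params = {
    "number_of_clusters": 8,
    "cluster_centroids_init":
        tfmot.clustering.keras.CentroidInitialization.LINEAR,
}

original_model = keras.Sequential([keras.layers.Dense(5, input_shape=(5,))])

# Wrap the model so each clusterable kernel is represented as cluster
# centroids plus per-weight cluster assignments.
clustered_model = tfmot.clustering.keras.cluster_weights(
    original_model, **clustering_params)
clustered_model.compile(loss="mse", optimizer="adam")
clustered_model.fit(np.random.rand(32, 5), np.random.rand(32, 5), epochs=1)

# Strip the clustering wrappers; the restored kernel keeps the clustered
# values, so it should hold at most number_of_clusters distinct weights.
stripped_model = tfmot.clustering.keras.strip_clustering(clustered_model)
unique_weights = set(stripped_model.layers[0].kernel.numpy().flatten())
assert len(unique_weights) <= clustering_params["number_of_clusters"]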