@@ -152,7 +152,8 @@ def testValuesRemainClusteredAfterTraining(self):
152
152
153
153
@keras_parameterized .run_all_keras_modes
154
154
def testSparsityIsPreservedDuringTraining (self ):
155
- """Set a specific random seed to ensure that we get some null weights to test sparsity preservation with."""
155
+ # Set a specific random seed to ensure that we get some null weights to
156
+ # test sparsity preservation with.
156
157
tf .random .set_seed (1 )
157
158
158
159
# Verifies that training a clustered model does not destroy the sparsity of
@@ -187,7 +188,8 @@ def testSparsityIsPreservedDuringTraining(self):
187
188
stripped_model_after_tuning = cluster .strip_clustering (clustered_model )
188
189
weights_after_tuning = stripped_model_after_tuning .layers [0 ].kernel
189
190
non_zero_weight_indices_after_tuning = np .nonzero (weights_after_tuning )
190
- weights_as_list_after_tuning = weights_after_tuning .numpy ().reshape (- 1 ,).tolist ()
191
+ weights_as_list_after_tuning = weights_after_tuning .numpy ().reshape (
192
+ - 1 ,).tolist ()
191
193
unique_weights_after_tuning = set (weights_as_list_after_tuning )
192
194
193
195
# Check that the null weights stayed the same before and after tuning.
@@ -245,15 +247,15 @@ def testEndToEndDeepLayer(self):
245
247
246
248
def clusters_check (stripped_model ):
247
249
# inner dense layer
248
- weights_as_list = stripped_model . submodules [ 1 ]. trainable_weights [ 0 ].\
249
- numpy ().flatten ()
250
+ weights_as_list = (
251
+ stripped_model . submodules [ 1 ]. trainable_weights [ 0 ]. numpy ().flatten () )
250
252
unique_weights = set (weights_as_list )
251
253
self .assertLessEqual (
252
254
len (unique_weights ), self .params ["number_of_clusters" ])
253
255
254
256
# outer dense layer
255
- weights_as_list = stripped_model . submodules [ 4 ]. trainable_weights [ 0 ].\
256
- numpy ().flatten ()
257
+ weights_as_list = (
258
+ stripped_model . submodules [ 4 ]. trainable_weights [ 0 ]. numpy ().flatten () )
257
259
unique_weights = set (weights_as_list )
258
260
self .assertLessEqual (
259
261
len (unique_weights ), self .params ["number_of_clusters" ])
@@ -276,23 +278,22 @@ def testEndToEndDeepLayer2(self):
276
278
277
279
def clusters_check (stripped_model ):
278
280
# first inner dense layer
279
- weights_as_list = stripped_model . submodules [ 1 ]. trainable_weights [ 0 ].\
280
- numpy ().flatten ()
281
+ weights_as_list = (
282
+ stripped_model . submodules [ 1 ]. trainable_weights [ 0 ]. numpy ().flatten () )
281
283
unique_weights = set (weights_as_list )
282
284
self .assertLessEqual (
283
285
len (unique_weights ), self .params ["number_of_clusters" ])
284
286
285
287
# second inner dense layer
286
- weights_as_list = stripped_model .submodules [4 ].\
287
- trainable_weights [0 ].\
288
- numpy ().flatten ()
288
+ weights_as_list = (
289
+ stripped_model .submodules [4 ].trainable_weights [0 ].numpy ().flatten ())
289
290
unique_weights = set (weights_as_list )
290
291
self .assertLessEqual (
291
292
len (unique_weights ), self .params ["number_of_clusters" ])
292
293
293
294
# outer dense layer
294
- weights_as_list = stripped_model . submodules [ 7 ]. trainable_weights [ 0 ].\
295
- numpy ().flatten ()
295
+ weights_as_list = (
296
+ stripped_model . submodules [ 7 ]. trainable_weights [ 0 ]. numpy ().flatten () )
296
297
unique_weights = set (weights_as_list )
297
298
self .assertLessEqual (
298
299
len (unique_weights ), self .params ["number_of_clusters" ])
@@ -301,51 +302,46 @@ def clusters_check(stripped_model):
301
302
302
303
@keras_parameterized .run_all_keras_modes
303
304
def testWeightsAreLearningDuringClustering (self ):
304
- """Verifies that training a clustered model does update
305
- original_weights, clustered_centroids and bias."""
306
- original_model = keras .Sequential ([
307
- layers .Dense (5 , input_shape = (5 ,))
308
- ])
305
+ """Verifies that weights are updated during training a clustered model.
306
+
307
+ Training a clustered model should update original_weights,
308
+ clustered_centroids and bias.
309
+ """
310
+ original_model = keras .Sequential ([layers .Dense (5 , input_shape = (5 ,))])
309
311
310
312
clustered_model = cluster .cluster_weights (original_model , ** self .params )
311
313
312
314
clustered_model .compile (
313
- loss = keras .losses .categorical_crossentropy ,
314
- optimizer = "adam" ,
315
- metrics = ["accuracy" ],
315
+ loss = keras .losses .categorical_crossentropy ,
316
+ optimizer = "adam" ,
317
+ metrics = ["accuracy" ],
316
318
)
317
319
318
320
class CheckWeightsCallback (keras .callbacks .Callback ):
321
+
319
322
def on_train_batch_begin (self , batch , logs = None ):
320
323
# Save weights before batch
321
324
self .original_weight_kernel = (
322
- self .model .layers [0 ].original_clusterable_weights ['kernel' ].numpy ()
323
- )
325
+ self .model .layers [0 ].original_clusterable_weights ["kernel" ].numpy ())
324
326
self .cluster_centroids_kernel = (
325
- self .model .layers [0 ].cluster_centroids ['kernel' ].numpy ()
326
- )
327
- self .bias = (
328
- self .model .layers [0 ].layer .bias .numpy ()
329
- )
327
+ self .model .layers [0 ].cluster_centroids ["kernel" ].numpy ())
328
+ self .bias = (self .model .layers [0 ].layer .bias .numpy ())
330
329
331
330
def on_train_batch_end (self , batch , logs = None ):
332
331
# Check weights are different after batch
333
332
assert not np .array_equal (
334
- self .original_weight_kernel ,
335
- self .model .layers [0 ].original_clusterable_weights ['kernel' ].numpy ()
336
- )
333
+ self .original_weight_kernel ,
334
+ self .model .layers [0 ].original_clusterable_weights ["kernel" ].numpy ())
337
335
assert not np .array_equal (
338
- self .cluster_centroids_kernel ,
339
- self .model .layers [0 ].cluster_centroids ['kernel' ].numpy ()
340
- )
341
- assert not np .array_equal (
342
- self .bias ,
343
- self .model .layers [0 ].layer .bias .numpy ()
344
- )
345
-
346
- clustered_model .fit (x = self .dataset_generator (),
347
- steps_per_epoch = 5 ,
348
- callbacks = [CheckWeightsCallback ()])
336
+ self .cluster_centroids_kernel ,
337
+ self .model .layers [0 ].cluster_centroids ["kernel" ].numpy ())
338
+ assert not np .array_equal (self .bias ,
339
+ self .model .layers [0 ].layer .bias .numpy ())
340
+
341
+ clustered_model .fit (
342
+ x = self .dataset_generator (),
343
+ steps_per_epoch = 5 ,
344
+ callbacks = [CheckWeightsCallback ()])
349
345
350
346
351
347
if __name__ == "__main__" :
0 commit comments