2121from absl .testing import parameterized
2222import tensorflow as tf
2323
24- from tensorflow .python .keras import keras_parameterized
2524from tensorflow_model_optimization .python .core .clustering .keras import cluster
2625from tensorflow_model_optimization .python .core .clustering .keras import cluster_config
2726from tensorflow_model_optimization .python .core .clustering .keras import cluster_wrapper
@@ -162,15 +161,13 @@ def _count_clustered_layers(self, model):
162161 count += 1
163162 return count
164163
165- @keras_parameterized .run_all_keras_modes
166164 def testClusterKerasClusterableLayer (self ):
167165 """Verifies that a built-in keras layer marked as clusterable is being clustered correctly."""
168166 wrapped_layer = self ._build_clustered_layer_model (
169167 self .keras_clusterable_layer )
170168
171169 self ._validate_clustered_layer (self .keras_clusterable_layer , wrapped_layer )
172170
173- @keras_parameterized .run_all_keras_modes
174171 def testClusterKerasClusterableLayerWithSparsityPreservation (self ):
175172 """Verifies that a built-in keras layer marked as clusterable is being clustered correctly when sparsity preservation is enabled."""
176173 preserve_sparsity_params = {'preserve_sparsity' : True }
@@ -180,7 +177,6 @@ def testClusterKerasClusterableLayerWithSparsityPreservation(self):
180177
181178 self ._validate_clustered_layer (self .keras_clusterable_layer , wrapped_layer )
182179
183- @keras_parameterized .run_all_keras_modes
184180 def testClusterKerasNonClusterableLayer (self ):
185181 """Verifies that a built-in keras layer not marked as clusterable is not being clustered."""
186182 wrapped_layer = self ._build_clustered_layer_model (
@@ -190,7 +186,6 @@ def testClusterKerasNonClusterableLayer(self):
190186 wrapped_layer )
191187 self .assertEqual ([], wrapped_layer .layer .get_clusterable_weights ())
192188
193- @keras_parameterized .run_all_keras_modes
194189 def testDepthwiseConv2DLayerNonClusterable (self ):
195190 """Verifies that we don't cluster a DepthwiseConv2D layer, because clustering of this type of layer gives big unrecoverable accuracy loss."""
196191 wrapped_layer = self ._build_clustered_layer_model (
@@ -200,7 +195,6 @@ def testDepthwiseConv2DLayerNonClusterable(self):
200195 wrapped_layer )
201196 self .assertEqual ([], wrapped_layer .layer .get_clusterable_weights ())
202197
203- @keras_parameterized .run_all_keras_modes
204198 def testDenseLayer (self ):
205199 """Verifies that we can cluster a Dense layer."""
206200 input_shape = (28 , 1 )
@@ -214,7 +208,6 @@ def testDenseLayer(self):
214208 self .assertEqual ([1 , 10 ],
215209 wrapped_layer .layer .get_clusterable_weights ()[0 ][1 ].shape )
216210
217- @keras_parameterized .run_all_keras_modes
218211 def testConv1DLayer (self ):
219212 """Verifies that we can cluster a Conv1D layer."""
220213 input_shape = (28 , 1 )
@@ -227,7 +220,6 @@ def testConv1DLayer(self):
227220 self .assertEqual ([5 , 1 , 3 ],
228221 wrapped_layer .layer .get_clusterable_weights ()[0 ][1 ].shape )
229222
230- @keras_parameterized .run_all_keras_modes
231223 def testConv1DTransposeLayer (self ):
232224 """Verifies that we can cluster a Conv1DTranspose layer."""
233225 input_shape = (28 , 1 )
@@ -240,7 +232,6 @@ def testConv1DTransposeLayer(self):
240232 self .assertEqual ([5 , 3 , 1 ],
241233 wrapped_layer .layer .get_clusterable_weights ()[0 ][1 ].shape )
242234
243- @keras_parameterized .run_all_keras_modes
244235 def testConv2DLayer (self ):
245236 """Verifies that we can cluster a Conv2D layer."""
246237 input_shape = (28 , 28 , 1 )
@@ -253,7 +244,6 @@ def testConv2DLayer(self):
253244 self .assertEqual ([4 , 5 , 1 , 3 ],
254245 wrapped_layer .layer .get_clusterable_weights ()[0 ][1 ].shape )
255246
256- @keras_parameterized .run_all_keras_modes
257247 def testConv2DTransposeLayer (self ):
258248 """Verifies that we can cluster a Conv2DTranspose layer."""
259249 input_shape = (28 , 28 , 1 )
@@ -266,7 +256,6 @@ def testConv2DTransposeLayer(self):
266256 self .assertEqual ([4 , 5 , 3 , 1 ],
267257 wrapped_layer .layer .get_clusterable_weights ()[0 ][1 ].shape )
268258
269- @keras_parameterized .run_all_keras_modes
270259 def testConv3DLayer (self ):
271260 """Verifies that we can cluster a Conv3D layer."""
272261 input_shape = (28 , 28 , 28 , 1 )
@@ -287,7 +276,6 @@ def testClusterKerasUnsupportedLayer(self):
287276 with self .assertRaises (ValueError ):
288277 cluster .cluster_weights (keras_unsupported_layer , ** self .params )
289278
290- @keras_parameterized .run_all_keras_modes
291279 def testClusterCustomClusterableLayer (self ):
292280 """Verifies that a custom clusterable layer is being clustered correctly."""
293281 wrapped_layer = self ._build_clustered_layer_model (
@@ -297,7 +285,6 @@ def testClusterCustomClusterableLayer(self):
297285 self .assertEqual ([('kernel' , wrapped_layer .layer .kernel )],
298286 wrapped_layer .layer .get_clusterable_weights ())
299287
300- @keras_parameterized .run_all_keras_modes
301288 def testClusterCustomClusterableLayerWithSparsityPreservation (self ):
302289 """Verifies that a custom clusterable layer is being clustered correctly when sparsity preservation is enabled."""
303290 preserve_sparsity_params = {'preserve_sparsity' : True }
@@ -424,7 +411,6 @@ def testStripClusteringSequentialModelWithBiasConstraint(self):
424411 keras_file = os .path .join (tmp_dir_name , 'cluster_test' )
425412 stripped_model .save (keras_file , save_traces = True )
426413
427- @keras_parameterized .run_all_keras_modes
428414 def testClusterSequentialModelSelectively (self ):
429415 clustered_model = keras .Sequential ()
430416 clustered_model .add (
@@ -437,7 +423,6 @@ def testClusterSequentialModelSelectively(self):
437423 self .assertNotIsInstance (clustered_model .layers [1 ],
438424 cluster_wrapper .ClusterWeights )
439425
440- @keras_parameterized .run_all_keras_modes
441426 def testClusterSequentialModelSelectivelyWithSparsityPreservation (self ):
442427 """Verifies that layers within a sequential model can be clustered selectively when sparsity preservation is enabled."""
443428 preserve_sparsity_params = {'preserve_sparsity' : True }
@@ -454,7 +439,6 @@ def testClusterSequentialModelSelectivelyWithSparsityPreservation(self):
454439 self .assertNotIsInstance (clustered_model .layers [1 ],
455440 cluster_wrapper .ClusterWeights )
456441
457- @keras_parameterized .run_all_keras_modes
458442 def testClusterFunctionalModelSelectively (self ):
459443 """Verifies that layers within a functional model can be clustered selectively."""
460444 i1 = keras .Input (shape = (10 ,))
@@ -469,7 +453,6 @@ def testClusterFunctionalModelSelectively(self):
469453 self .assertNotIsInstance (clustered_model .layers [3 ],
470454 cluster_wrapper .ClusterWeights )
471455
472- @keras_parameterized .run_all_keras_modes
473456 def testClusterFunctionalModelSelectivelyWithSparsityPreservation (self ):
474457 """Verifies that layers within a functional model can be clustered selectively when sparsity preservation is enabled."""
475458 preserve_sparsity_params = {'preserve_sparsity' : True }
@@ -486,7 +469,6 @@ def testClusterFunctionalModelSelectivelyWithSparsityPreservation(self):
486469 self .assertNotIsInstance (clustered_model .layers [3 ],
487470 cluster_wrapper .ClusterWeights )
488471
489- @keras_parameterized .run_all_keras_modes
490472 def testClusterModelValidLayersSuccessful (self ):
491473 """Verifies that clustering a sequential model results in all clusterable layers within the model being clustered."""
492474 model = keras .Sequential ([
@@ -500,7 +482,6 @@ def testClusterModelValidLayersSuccessful(self):
500482 for layer , clustered_layer in zip (model .layers , clustered_model .layers ):
501483 self ._validate_clustered_layer (layer , clustered_layer )
502484
503- @keras_parameterized .run_all_keras_modes
504485 def testClusterModelValidLayersSuccessfulWithSparsityPreservation (self ):
505486 """Verifies that clustering a sequential model results in all clusterable layers within the model being clustered when sparsity preservation is enabled."""
506487 preserve_sparsity_params = {'preserve_sparsity' : True }
@@ -540,7 +521,6 @@ def testClusterModelCustomNonClusterableLayerRaisesError(self):
540521 self .custom_clusterable_layer , custom_non_clusterable_layer
541522 ]), ** self .params )
542523
543- @keras_parameterized .run_all_keras_modes
544524 def testClusterModelDoesNotWrapAlreadyWrappedLayer (self ):
545525 """Verifies that clustering a model that contains an already clustered layer does not result in wrapping the clustered layer into another cluster_wrapper."""
546526 model = keras .Sequential ([
@@ -579,7 +559,6 @@ def testClusterSequentialModelNoInput(self):
579559 clustered_model = cluster .cluster_weights (model , ** self .params )
580560 self .assertEqual (self ._count_clustered_layers (clustered_model ), 2 )
581561
582- @keras_parameterized .run_all_keras_modes
583562 def testClusterSequentialModelWithInput (self ):
584563 """Verifies that a sequential model with an input layer is being clustered correctly."""
585564 # With InputLayer
@@ -607,7 +586,6 @@ def testClusterSequentialModelPreservesBuiltStateNoInput(self):
607586 json .loads (clustered_model .to_json ()))
608587 self .assertEqual (loaded_model .built , False )
609588
610- @keras_parameterized .run_all_keras_modes
611589 def testClusterSequentialModelPreservesBuiltStateWithInput (self ):
612590 """Verifies that clustering a sequential model with an input layer preserves the built state of the model."""
613591 # With InputLayer
@@ -625,7 +603,6 @@ def testClusterSequentialModelPreservesBuiltStateWithInput(self):
625603 json .loads (clustered_model .to_json ()))
626604 self .assertEqual (loaded_model .built , True )
627605
628- @keras_parameterized .run_all_keras_modes
629606 def testClusterFunctionalModelPreservesBuiltState (self ):
630607 """Verifies that clustering a functional model preserves the built state of the model."""
631608 i1 = keras .Input (shape = (10 ,))
@@ -644,7 +621,6 @@ def testClusterFunctionalModelPreservesBuiltState(self):
644621 json .loads (clustered_model .to_json ()))
645622 self .assertEqual (loaded_model .built , True )
646623
647- @keras_parameterized .run_all_keras_modes
648624 def testClusterFunctionalModel (self ):
649625 """Verifies that a functional model is being clustered correctly."""
650626 i1 = keras .Input (shape = (10 ,))
@@ -656,7 +632,6 @@ def testClusterFunctionalModel(self):
656632 clustered_model = cluster .cluster_weights (model , ** self .params )
657633 self .assertEqual (self ._count_clustered_layers (clustered_model ), 3 )
658634
659- @keras_parameterized .run_all_keras_modes
660635 def testClusterFunctionalModelWithLayerReused (self ):
661636 """Verifies that a layer reused within a functional model multiple times is only being clustered once."""
662637 # The model reuses the Dense() layer. Make sure it's only clustered once.
@@ -668,22 +643,19 @@ def testClusterFunctionalModelWithLayerReused(self):
668643 clustered_model = cluster .cluster_weights (model , ** self .params )
669644 self .assertEqual (self ._count_clustered_layers (clustered_model ), 1 )
670645
671- @keras_parameterized .run_all_keras_modes
672646 def testClusterSubclassModel (self ):
673647 """Verifies that attempting to cluster an instance of a subclass of keras.Model raises an exception."""
674648 model = TestModel ()
675649 with self .assertRaises (ValueError ):
676650 _ = cluster .cluster_weights (model , ** self .params )
677651
678- @keras_parameterized .run_all_keras_modes
679652 def testClusterSubclassModelAsSubmodel (self ):
680653 """Verifies that attempting to cluster a model with submodel that is a subclass throws an exception."""
681654 model_subclass = TestModel ()
682655 model = keras .Sequential ([layers .Dense (10 ), model_subclass ])
683656 with self .assertRaisesRegex (ValueError , 'Subclassed models.*' ):
684657 _ = cluster .cluster_weights (model , ** self .params )
685658
686- @keras_parameterized .run_all_keras_modes
687659 def testStripClusteringSequentialModel (self ):
688660 """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
689661 model = keras .Sequential ([
@@ -697,7 +669,6 @@ def testStripClusteringSequentialModel(self):
697669 self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
698670 self .assertEqual (model .get_config (), stripped_model .get_config ())
699671
700- @keras_parameterized .run_all_keras_modes
701672 def testClusterStrippingFunctionalModel (self ):
702673 """Verifies that stripping the clustering wrappers from a functional model produces the expected config."""
703674 i1 = keras .Input (shape = (10 ,))
@@ -713,7 +684,6 @@ def testClusterStrippingFunctionalModel(self):
713684 self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
714685 self .assertEqual (model .get_config (), stripped_model .get_config ())
715686
716- @keras_parameterized .run_all_keras_modes
717687 def testClusterWeightsStrippedWeights (self ):
718688 """Verifies that stripping the clustering wrappers from a functional model preserves the clustered weights."""
719689 i1 = keras .Input (shape = (10 ,))
@@ -728,7 +698,6 @@ def testClusterWeightsStrippedWeights(self):
728698 self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
729699 self .assertLen (stripped_model .get_weights (), cluster_weight_length )
730700
731- @keras_parameterized .run_all_keras_modes
732701 def testStrippedKernel (self ):
733702 """Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value."""
734703 i1 = keras .Input (shape = (1 , 1 , 1 ))
@@ -746,7 +715,6 @@ def testStrippedKernel(self):
746715 self .assertIsNot (stripped_conv2d_layer .kernel , clustered_kernel )
747716 self .assertIn (stripped_conv2d_layer .kernel , stripped_conv2d_layer .weights )
748717
749- @keras_parameterized .run_all_keras_modes
750718 def testStripSelectivelyClusteredFunctionalModel (self ):
751719 """Verifies that invoking strip_clustering() on a selectively clustered functional model strips the clustering wrappers from the clustered layers."""
752720 i1 = keras .Input (shape = (10 ,))
@@ -761,7 +729,6 @@ def testStripSelectivelyClusteredFunctionalModel(self):
761729 self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
762730 self .assertIsInstance (stripped_model .layers [2 ], layers .Dense )
763731
764- @keras_parameterized .run_all_keras_modes
765732 def testStripSelectivelyClusteredSequentialModel (self ):
766733 """Verifies that invoking strip_clustering() on a selectively clustered sequential model strips the clustering wrappers from the clustered layers."""
767734 clustered_model = keras .Sequential ([
@@ -775,7 +742,6 @@ def testStripSelectivelyClusteredSequentialModel(self):
775742 self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
776743 self .assertIsInstance (stripped_model .layers [0 ], layers .Dense )
777744
778- @keras_parameterized .run_all_keras_modes
779745 def testStripClusteringAndSetOriginalWeightsBack (self ):
780746 """Verifies that we can set_weights onto the stripped model."""
781747 model = keras .Sequential ([
0 commit comments