1616
1717import json
1818
19+ import tempfile
20+
21+ import os
1922from absl .testing import parameterized
2023import tensorflow as tf
2124
@@ -253,6 +256,76 @@ def testClusterKerasCustomLayer(self):
253256 with self .assertRaises (ValueError ):
254257 cluster_wrapper .ClusterWeights (keras_custom_layer , ** self .params )
255258
259+ def testStripClusteringSequentialModelWithKernelRegularizer (self ):
260+ """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
261+ model = keras .Sequential ([
262+ layers .Dense (10 , input_shape = (10 ,)),
263+ layers .Dense (10 , kernel_regularizer = tf .keras .regularizers .L1 (0.01 )),
264+ ])
265+ clustered_model = cluster .cluster_weights (model , ** self .params )
266+ stripped_model = cluster .strip_clustering (clustered_model )
267+ # check that kernel regularizer is present in the second dense layer
268+ self .assertIsNotNone (stripped_model .layers [1 ].kernel_regularizer )
269+ with tempfile .TemporaryDirectory () as tmp_dir_name :
270+ keras_file = os .path .join (tmp_dir_name , 'cluster_test' )
271+ stripped_model .save (keras_file , save_traces = True )
272+
273+ def testStripClusteringSequentialModelWithBiasRegularizer (self ):
274+ """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
275+ model = keras .Sequential ([
276+ layers .Dense (10 , input_shape = (10 ,)),
277+ layers .Dense (10 , bias_regularizer = tf .keras .regularizers .L1 (0.01 )),
278+ ])
279+ clustered_model = cluster .cluster_weights (model , ** self .params )
280+ stripped_model = cluster .strip_clustering (clustered_model )
281+ # check that kernel regularizer is present in the second dense layer
282+ self .assertIsNotNone (stripped_model .layers [1 ].bias_regularizer )
283+ with tempfile .TemporaryDirectory () as tmp_dir_name :
284+ keras_file = os .path .join (tmp_dir_name , 'cluster_test' )
285+ stripped_model .save (keras_file , save_traces = True )
286+
287+ def testStripClusteringSequentialModelWithActivityRegularizer (self ):
288+ """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
289+ model = keras .Sequential ([
290+ layers .Dense (10 , input_shape = (10 ,)),
291+ layers .Dense (10 , activity_regularizer = tf .keras .regularizers .L1 (0.01 )),
292+ ])
293+ clustered_model = cluster .cluster_weights (model , ** self .params )
294+ stripped_model = cluster .strip_clustering (clustered_model )
295+ # check that kernel regularizer is present in the second dense layer
296+ self .assertIsNotNone (stripped_model .layers [1 ].activity_regularizer )
297+ with tempfile .TemporaryDirectory () as tmp_dir_name :
298+ keras_file = os .path .join (tmp_dir_name , 'cluster_test' )
299+ stripped_model .save (keras_file , save_traces = True )
300+
301+ def testStripClusteringSequentialModelWithKernelConstraint (self ):
302+ """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
303+ model = keras .Sequential ([
304+ layers .Dense (10 , input_shape = (10 ,)),
305+ layers .Dense (10 , kernel_constraint = tf .keras .constraints .max_norm (2. )),
306+ ])
307+ clustered_model = cluster .cluster_weights (model , ** self .params )
308+ stripped_model = cluster .strip_clustering (clustered_model )
309+ # check that kernel regularizer is present in the second dense layer
310+ self .assertIsNotNone (stripped_model .layers [1 ].kernel_constraint )
311+ with tempfile .TemporaryDirectory () as tmp_dir_name :
312+ keras_file = os .path .join (tmp_dir_name , 'cluster_test' )
313+ stripped_model .save (keras_file , save_traces = True )
314+
315+ def testStripClusteringSequentialModelWithBiasConstraint (self ):
316+ """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
317+ model = keras .Sequential ([
318+ layers .Dense (10 , input_shape = (10 ,)),
319+ layers .Dense (10 , bias_constraint = tf .keras .constraints .max_norm (2. )),
320+ ])
321+ clustered_model = cluster .cluster_weights (model , ** self .params )
322+ stripped_model = cluster .strip_clustering (clustered_model )
323+ # check that kernel regularizer is present in the second dense layer
324+ self .assertIsNotNone (stripped_model .layers [1 ].bias_constraint )
325+ with tempfile .TemporaryDirectory () as tmp_dir_name :
326+ keras_file = os .path .join (tmp_dir_name , 'cluster_test' )
327+ stripped_model .save (keras_file , save_traces = True )
328+
256329 def testClusterMyClusterableLayer (self ):
257330 # we have weights to cluster.
258331 layer = self .clusterable_layer
@@ -539,7 +612,7 @@ def testClusterSubclassModelAsSubmodel(self):
539612 def testStripClusteringSequentialModel (self ):
540613 """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
541614 model = keras .Sequential ([
542- layers .Dense (10 ),
615+ layers .Dense (10 , input_shape = ( 5 ,) ),
543616 layers .Dense (10 ),
544617 ])
545618
@@ -582,7 +655,7 @@ def testClusterWeightsStrippedWeights(self):
582655
583656 @keras_parameterized .run_all_keras_modes
584657 def testStrippedKernel (self ):
585- """Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value ."""
658+ """Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value."""
586659 i1 = keras .Input (shape = (1 , 1 , 1 ))
587660 x1 = layers .Conv2D (1 , 1 )(i1 )
588661 outputs = x1
@@ -596,8 +669,7 @@ def testStrippedKernel(self):
596669
597670 self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
598671 self .assertIsNot (stripped_conv2d_layer .kernel , clustered_kernel )
599- self .assertEqual (stripped_conv2d_layer .kernel ,
600- stripped_conv2d_layer .weights [0 ])
672+ self .assertIn (stripped_conv2d_layer .kernel , stripped_conv2d_layer .weights )
601673
602674 @keras_parameterized .run_all_keras_modes
603675 def testStripSelectivelyClusteredFunctionalModel (self ):
@@ -628,5 +700,23 @@ def testStripSelectivelyClusteredSequentialModel(self):
628700 self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
629701 self .assertIsInstance (stripped_model .layers [0 ], layers .Dense )
630702
703+ @keras_parameterized .run_all_keras_modes
704+ def testStripClusteringAndSetOriginalWeightsBack (self ):
705+ """Verifies that we can set_weights onto the stripped model."""
706+ model = keras .Sequential ([
707+ layers .Dense (10 , input_shape = (5 ,)),
708+ layers .Dense (10 ),
709+ ])
710+
711+ # Save original weights
712+ original_weights = model .get_weights ()
713+
714+ # Cluster and strip
715+ clustered_model = cluster .cluster_weights (model , ** self .params )
716+ stripped_model = cluster .strip_clustering (clustered_model )
717+
718+ # Set back original weights onto the strip model
719+ stripped_model .set_weights (original_weights )
720+
631721if __name__ == '__main__' :
632722 test .main ()
0 commit comments