1616
1717import json
1818
19+ import tempfile
20+
21+ import os
1922from absl .testing import parameterized
2023import tensorflow as tf
2124
@@ -282,6 +285,76 @@ def testClusterKerasCustomLayer(self):
282285 with self .assertRaises (ValueError ):
283286 cluster_wrapper .ClusterWeights (keras_custom_layer , ** self .params )
284287
288+ def testStripClusteringSequentialModelWithKernelRegularizer (self ):
289+ """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
290+ model = keras .Sequential ([
291+ layers .Dense (10 , input_shape = (10 ,)),
292+ layers .Dense (10 , kernel_regularizer = tf .keras .regularizers .L1 (0.01 )),
293+ ])
294+ clustered_model = cluster .cluster_weights (model , ** self .params )
295+ stripped_model = cluster .strip_clustering (clustered_model )
296+ # check that kernel regularizer is present in the second dense layer
297+ self .assertIsNotNone (stripped_model .layers [1 ].kernel_regularizer )
298+ with tempfile .TemporaryDirectory () as tmp_dir_name :
299+ keras_file = os .path .join (tmp_dir_name , 'cluster_test' )
300+ stripped_model .save (keras_file , save_traces = True )
301+
302+ def testStripClusteringSequentialModelWithBiasRegularizer (self ):
303+ """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
304+ model = keras .Sequential ([
305+ layers .Dense (10 , input_shape = (10 ,)),
306+ layers .Dense (10 , bias_regularizer = tf .keras .regularizers .L1 (0.01 )),
307+ ])
308+ clustered_model = cluster .cluster_weights (model , ** self .params )
309+ stripped_model = cluster .strip_clustering (clustered_model )
310+ # check that kernel regularizer is present in the second dense layer
311+ self .assertIsNotNone (stripped_model .layers [1 ].bias_regularizer )
312+ with tempfile .TemporaryDirectory () as tmp_dir_name :
313+ keras_file = os .path .join (tmp_dir_name , 'cluster_test' )
314+ stripped_model .save (keras_file , save_traces = True )
315+
316+ def testStripClusteringSequentialModelWithActivityRegularizer (self ):
317+ """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
318+ model = keras .Sequential ([
319+ layers .Dense (10 , input_shape = (10 ,)),
320+ layers .Dense (10 , activity_regularizer = tf .keras .regularizers .L1 (0.01 )),
321+ ])
322+ clustered_model = cluster .cluster_weights (model , ** self .params )
323+ stripped_model = cluster .strip_clustering (clustered_model )
324+ # check that kernel regularizer is present in the second dense layer
325+ self .assertIsNotNone (stripped_model .layers [1 ].activity_regularizer )
326+ with tempfile .TemporaryDirectory () as tmp_dir_name :
327+ keras_file = os .path .join (tmp_dir_name , 'cluster_test' )
328+ stripped_model .save (keras_file , save_traces = True )
329+
330+ def testStripClusteringSequentialModelWithKernelConstraint (self ):
331+ """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
332+ model = keras .Sequential ([
333+ layers .Dense (10 , input_shape = (10 ,)),
334+ layers .Dense (10 , kernel_constraint = tf .keras .constraints .max_norm (2. )),
335+ ])
336+ clustered_model = cluster .cluster_weights (model , ** self .params )
337+ stripped_model = cluster .strip_clustering (clustered_model )
338+ # check that kernel regularizer is present in the second dense layer
339+ self .assertIsNotNone (stripped_model .layers [1 ].kernel_constraint )
340+ with tempfile .TemporaryDirectory () as tmp_dir_name :
341+ keras_file = os .path .join (tmp_dir_name , 'cluster_test' )
342+ stripped_model .save (keras_file , save_traces = True )
343+
344+ def testStripClusteringSequentialModelWithBiasConstraint (self ):
345+ """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
346+ model = keras .Sequential ([
347+ layers .Dense (10 , input_shape = (10 ,)),
348+ layers .Dense (10 , bias_constraint = tf .keras .constraints .max_norm (2. )),
349+ ])
350+ clustered_model = cluster .cluster_weights (model , ** self .params )
351+ stripped_model = cluster .strip_clustering (clustered_model )
352+ # check that kernel regularizer is present in the second dense layer
353+ self .assertIsNotNone (stripped_model .layers [1 ].bias_constraint )
354+ with tempfile .TemporaryDirectory () as tmp_dir_name :
355+ keras_file = os .path .join (tmp_dir_name , 'cluster_test' )
356+ stripped_model .save (keras_file , save_traces = True )
357+
285358 def testClusterMyClusterableLayer (self ):
286359 # we have weights to cluster.
287360 layer = self .clusterable_layer
@@ -567,7 +640,7 @@ def testClusterSubclassModelAsSubmodel(self):
567640 def testStripClusteringSequentialModel (self ):
568641 """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
569642 model = keras .Sequential ([
570- layers .Dense (10 ),
643+ layers .Dense (10 , input_shape = ( 5 ,) ),
571644 layers .Dense (10 ),
572645 ])
573646
@@ -610,7 +683,7 @@ def testClusterWeightsStrippedWeights(self):
610683
611684 @keras_parameterized .run_all_keras_modes
612685 def testStrippedKernel (self ):
613- """Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value ."""
686+ """Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value."""
614687 i1 = keras .Input (shape = (1 , 1 , 1 ))
615688 x1 = layers .Conv2D (1 , 1 )(i1 )
616689 outputs = x1
@@ -624,8 +697,7 @@ def testStrippedKernel(self):
624697
625698 self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
626699 self .assertIsNot (stripped_conv2d_layer .kernel , clustered_kernel )
627- self .assertEqual (stripped_conv2d_layer .kernel ,
628- stripped_conv2d_layer .weights [0 ])
700+ self .assertIn (stripped_conv2d_layer .kernel , stripped_conv2d_layer .weights )
629701
630702 @keras_parameterized .run_all_keras_modes
631703 def testStripSelectivelyClusteredFunctionalModel (self ):
@@ -656,5 +728,23 @@ def testStripSelectivelyClusteredSequentialModel(self):
656728 self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
657729 self .assertIsInstance (stripped_model .layers [0 ], layers .Dense )
658730
731+ @keras_parameterized .run_all_keras_modes
732+ def testStripClusteringAndSetOriginalWeightsBack (self ):
733+ """Verifies that we can set_weights onto the stripped model."""
734+ model = keras .Sequential ([
735+ layers .Dense (10 , input_shape = (5 ,)),
736+ layers .Dense (10 ),
737+ ])
738+
739+ # Save original weights
740+ original_weights = model .get_weights ()
741+
742+ # Cluster and strip
743+ clustered_model = cluster .cluster_weights (model , ** self .params )
744+ stripped_model = cluster .strip_clustering (clustered_model )
745+
746+ # Set back original weights onto the strip model
747+ stripped_model .set_weights (original_weights )
748+
659749if __name__ == '__main__' :
660750 test .main ()
0 commit comments