@@ -128,6 +128,14 @@ def _verify_tflite(tflite_file, x_test):
128128 interpreter .invoke ()
129129 interpreter .get_tensor (output_index )
130130
131+ @staticmethod
132+ def _get_number_of_unique_weights (stripped_model , layer_nr , weight_name ):
133+ layer = stripped_model .layers [layer_nr ]
134+ weight = getattr (layer , weight_name )
135+ weights_as_list = weight .numpy ().flatten ()
136+ nr_of_unique_weights = len (set (weights_as_list ))
137+ return nr_of_unique_weights
138+
131139 @keras_parameterized .run_all_keras_modes
132140 def testValuesRemainClusteredAfterTraining (self ):
133141 """Verifies that training a clustered model does not destroy the clusters."""
@@ -150,73 +158,59 @@ def testValuesRemainClusteredAfterTraining(self):
150158 unique_weights = set (weights_as_list )
151159 self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
152160
161+
153162 @keras_parameterized .run_all_keras_modes
154163 def testSparsityIsPreservedDuringTraining (self ):
155- # Set a specific random seed to ensure that we get some null weights to
156- # test sparsity preservation with.
164+ """ Set a specific random seed to ensure that we get some null weights
165+ to test sparsity preservation with."""
157166 tf .random .set_seed (1 )
158-
159- # Verifies that training a clustered model does not destroy the sparsity of
160- # the weights.
167+ # Verifies that training a clustered model with null weights in it
168+ # does not destroy the sparsity of the weights.
161169 original_model = keras .Sequential ([
162170 layers .Dense (5 , input_shape = (5 ,)),
163- layers .Dense ( 5 ),
171+ layers .Flatten ( ),
164172 ])
165-
166- # Using a mininum number of centroids to make it more likely that some
167- # weights will be zero.
173+ # Reset the kernel weights to reflect potential zero drifting of
174+ # the cluster centroids
175+ first_layer_weights = original_model .layers [0 ].get_weights ()
176+ first_layer_weights [0 ][:][0 :2 ] = 0.0
177+ first_layer_weights [0 ][:][3 ] = [- 0.13 , - 0.08 , - 0.05 , 0.005 , 0.13 ]
178+ first_layer_weights [0 ][:][4 ] = [- 0.13 , - 0.08 , - 0.05 , 0.005 , 0.13 ]
179+ original_model .layers [0 ].set_weights (first_layer_weights )
168180 clustering_params = {
169- "number_of_clusters" : 3 ,
181+ "number_of_clusters" : 6 ,
170182 "cluster_centroids_init" : CentroidInitialization .LINEAR ,
171183 "preserve_sparsity" : True
172184 }
173-
174185 clustered_model = experimental_cluster .cluster_weights (
175186 original_model , ** clustering_params )
176-
177187 stripped_model_before_tuning = cluster .strip_clustering (clustered_model )
178- weights_before_tuning = stripped_model_before_tuning .layers [0 ].kernel
179- non_zero_weight_indices_before_tuning = np .nonzero (weights_before_tuning )
180-
188+ nr_of_unique_weights_before = self ._get_number_of_unique_weights (
189+ stripped_model_before_tuning , 0 , 'kernel' )
181190 clustered_model .compile (
182191 loss = keras .losses .categorical_crossentropy ,
183192 optimizer = "adam" ,
184193 metrics = ["accuracy" ],
185194 )
186- clustered_model .fit (x = self .dataset_generator2 (), steps_per_epoch = 1 )
187-
195+ clustered_model .fit (x = self .dataset_generator (), steps_per_epoch = 100 )
188196 stripped_model_after_tuning = cluster .strip_clustering (clustered_model )
189197 weights_after_tuning = stripped_model_after_tuning .layers [0 ].kernel
190- non_zero_weight_indices_after_tuning = np . nonzero ( weights_after_tuning )
191- weights_as_list_after_tuning = weights_after_tuning . numpy (). reshape (
192- - 1 ,). tolist ()
193- unique_weights_after_tuning = set ( weights_as_list_after_tuning )
194-
198+ nr_of_unique_weights_after = self . _get_number_of_unique_weights (
199+ stripped_model_after_tuning , 0 , 'kernel' )
200+ # Check after sparsity-aware clustering, despite zero centroid can drift,
201+ # the final number of unique weights remains the same
202+ self . assertEqual ( nr_of_unique_weights_before , nr_of_unique_weights_after )
195203 # Check that the null weights stayed the same before and after tuning.
204+ # There might be new weights that become zeros but sparsity-aware
205+ # clustering preserves the original null weights in the original positions
206+ # of the weight array
196207 self .assertTrue (
197- np .array_equal (non_zero_weight_indices_before_tuning ,
198- non_zero_weight_indices_after_tuning ))
199-
208+ np .array_equal (first_layer_weights [0 ][:][0 :2 ],
209+ weights_after_tuning [:][0 :2 ]))
200210 # Check that the number of unique weights matches the number of clusters.
201211 self .assertLessEqual (
202- len (unique_weights_after_tuning ), self .params ["number_of_clusters" ])
203-
204- @keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
205- def testEndToEndSequential (self ):
206- """Test End to End clustering - sequential model."""
207- original_model = keras .Sequential ([
208- layers .Dense (5 , input_shape = (5 ,)),
209- layers .Dense (5 ),
210- ])
211-
212- def clusters_check (stripped_model ):
213- # dense layer
214- weights_as_list = stripped_model .get_weights ()[0 ].reshape (- 1 ,).tolist ()
215- unique_weights = set (weights_as_list )
216- self .assertLessEqual (
217- len (unique_weights ), self .params ["number_of_clusters" ])
218-
219- self .end_to_end_testing (original_model , clusters_check )
212+ nr_of_unique_weights_after ,
213+ clustering_params ["number_of_clusters" ])
220214
221215 @keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
222216 def testEndToEndFunctional (self ):
0 commit comments