@@ -128,6 +128,14 @@ def _verify_tflite(tflite_file, x_test):
     interpreter.invoke()
     interpreter.get_tensor(output_index)
 
+  @staticmethod
+  def _get_number_of_unique_weights(stripped_model, layer_nr, weight_name):
+    layer = stripped_model.layers[layer_nr]
+    weight = getattr(layer, weight_name)
+    weights_as_list = weight.numpy().flatten()
+    nr_of_unique_weights = len(set(weights_as_list))
+    return nr_of_unique_weights
+
   @keras_parameterized.run_all_keras_modes
   def testValuesRemainClusteredAfterTraining(self):
     """Verifies that training a clustered model does not destroy the clusters."""
@@ -150,73 +158,59 @@ def testValuesRemainClusteredAfterTraining(self):
     unique_weights = set(weights_as_list)
     self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
 
+
   @keras_parameterized.run_all_keras_modes
   def testSparsityIsPreservedDuringTraining(self):
-    # Set a specific random seed to ensure that we get some null weights to
-    # test sparsity preservation with.
+    """Set a specific random seed to ensure that we get some null weights
+    to test sparsity preservation with."""
     tf.random.set_seed(1)
-
-    # Verifies that training a clustered model does not destroy the sparsity of
-    # the weights.
+    # Verifies that training a clustered model with null weights in it
+    # does not destroy the sparsity of the weights.
     original_model = keras.Sequential([
         layers.Dense(5, input_shape=(5,)),
-        layers.Dense(5),
+        layers.Flatten(),
     ])
-
-    # Using a mininum number of centroids to make it more likely that some
-    # weights will be zero.
+    # Reset the kernel weights to reflect potential zero drifting of
+    # the cluster centroids.
+    first_layer_weights = original_model.layers[0].get_weights()
+    first_layer_weights[0][:][0:2] = 0.0
+    first_layer_weights[0][:][3] = [-0.13, -0.08, -0.05, 0.005, 0.13]
+    first_layer_weights[0][:][4] = [-0.13, -0.08, -0.05, 0.005, 0.13]
+    original_model.layers[0].set_weights(first_layer_weights)
     clustering_params = {
-        "number_of_clusters": 3,
+        "number_of_clusters": 6,
         "cluster_centroids_init": CentroidInitialization.LINEAR,
         "preserve_sparsity": True
     }
-
     clustered_model = experimental_cluster.cluster_weights(
         original_model, **clustering_params)
-
     stripped_model_before_tuning = cluster.strip_clustering(clustered_model)
-    weights_before_tuning = stripped_model_before_tuning.layers[0].kernel
-    non_zero_weight_indices_before_tuning = np.nonzero(weights_before_tuning)
-
+    nr_of_unique_weights_before = self._get_number_of_unique_weights(
+        stripped_model_before_tuning, 0, 'kernel')
     clustered_model.compile(
         loss=keras.losses.categorical_crossentropy,
         optimizer="adam",
         metrics=["accuracy"],
     )
-    clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1)
-
+    clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=100)
     stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
     weights_after_tuning = stripped_model_after_tuning.layers[0].kernel
-    non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
-    weights_as_list_after_tuning = weights_after_tuning.numpy().reshape(
-        -1,).tolist()
-    unique_weights_after_tuning = set(weights_as_list_after_tuning)
-
+    nr_of_unique_weights_after = self._get_number_of_unique_weights(
+        stripped_model_after_tuning, 0, 'kernel')
+    # Check that even though the zero centroid can drift during sparsity-aware
+    # clustering, the final number of unique weights remains the same.
+    self.assertEqual(nr_of_unique_weights_before, nr_of_unique_weights_after)
     # Check that the null weights stayed the same before and after tuning.
+    # Some new weights may become zero, but sparsity-aware clustering
+    # preserves the original null weights in their original positions
+    # in the weight array.
     self.assertTrue(
-        np.array_equal(non_zero_weight_indices_before_tuning,
-                       non_zero_weight_indices_after_tuning))
-
+        np.array_equal(first_layer_weights[0][:][0:2],
+                       weights_after_tuning[:][0:2]))
     # Check that the number of unique weights matches the number of clusters.
     self.assertLessEqual(
-        len(unique_weights_after_tuning), self.params["number_of_clusters"])
-
-  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
-  def testEndToEndSequential(self):
-    """Test End to End clustering - sequential model."""
-    original_model = keras.Sequential([
-        layers.Dense(5, input_shape=(5,)),
-        layers.Dense(5),
-    ])
-
-    def clusters_check(stripped_model):
-      # dense layer
-      weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
-      unique_weights = set(weights_as_list)
-      self.assertLessEqual(
-          len(unique_weights), self.params["number_of_clusters"])
-
-    self.end_to_end_testing(original_model, clusters_check)
+        nr_of_unique_weights_after,
+        clustering_params["number_of_clusters"])
 
   @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
   def testEndToEndFunctional(self):
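Note: for context, a rough end-to-end sketch of the sparsity-preserving clustering flow this test exercises. The import paths are an assumption, mirroring the module aliases (cluster, experimental_cluster, CentroidInitialization) already used in the diff; the random training data and the two prints are illustrative only:

    import numpy as np
    from tensorflow import keras
    from tensorflow.keras import layers

    # Assumed imports, matching the aliases used in the test above.
    from tensorflow_model_optimization.python.core.clustering.keras import cluster
    from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
    from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
        cluster as experimental_cluster)

    # Small model with some null kernel weights, as the test sets up.
    model = keras.Sequential([layers.Dense(5, input_shape=(5,)), layers.Flatten()])
    weights = model.layers[0].get_weights()
    weights[0][0:2] = 0.0  # zero out the first two kernel rows
    model.layers[0].set_weights(weights)

    # Sparsity-preserving clustering via the experimental API, as in the diff.
    clustered = experimental_cluster.cluster_weights(
        model,
        number_of_clusters=6,
        cluster_centroids_init=cluster_config.CentroidInitialization.LINEAR,
        preserve_sparsity=True)
    clustered.compile(loss=keras.losses.categorical_crossentropy,
                      optimizer="adam", metrics=["accuracy"])
    clustered.fit(np.random.rand(32, 5), np.random.rand(32, 5), epochs=1, verbose=0)

    # Strip the clustering wrappers; the clustered weights remain.
    stripped = cluster.strip_clustering(clustered)
    kernel = stripped.layers[0].kernel.numpy()
    print(len(set(kernel.flatten())))     # bounded by number_of_clusters above
    print(np.count_nonzero(kernel[0:2]))  # the test asserts the null rows stay zero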