@@ -85,6 +85,11 @@ def dataset_generator2(self):
85
85
for x , y in zip (self .x_train2 , self .y_train2 ):
86
86
yield np .array ([x ]), np .array ([y ])
87
87
88
+ def _batch (self , dims , batch_size ):
89
+ if dims [0 ] is None :
90
+ dims [0 ] = batch_size
91
+ return dims
92
+
88
93
def end_to_end_testing (self , original_model , clusters_check = None ):
89
94
"""Test End to End clustering."""
90
95
@@ -225,6 +230,80 @@ def testSparsityIsPreservedDuringTraining(self):
225
230
nr_of_unique_weights_after ,
226
231
clustering_params ["number_of_clusters" ])
227
232
233
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def testEndToEndSequential(self):
  """Test End to End clustering - sequential model."""
  original_model = keras.Sequential([
      layers.Dense(5, input_shape=(5,)),
      layers.Dense(5),
  ])

  def clusters_check(stripped_model):
    # First dense layer: after stripping, the kernel must contain at most
    # `number_of_clusters` distinct values.
    kernel_values = stripped_model.get_weights()[0].reshape(-1,).tolist()
    self.assertLessEqual(
        len(set(kernel_values)), self.params["number_of_clusters"])

  self.end_to_end_testing(original_model, clusters_check)
249
+
250
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def testEndToEndConv1DAndConv1DTranspose(self):
  """Test End to End clustering - model with Conv1D and Conv1DTranspose."""
  # Build a tiny functional model: expand the flat input to (batch, 16, 1)
  # and run it through a Conv1D followed by a Conv1DTranspose.
  inp = layers.Input(batch_shape=(1, 16))
  x = layers.Conv1D(
      10, 16, 4, padding="valid", use_bias=False)(
          tf.expand_dims(inp, axis=-1))
  y = layers.Conv1DTranspose(1, 16, 4, padding="valid", use_bias=False)(x)
  model = keras.models.Model(inputs=inp, outputs=[y])

  def apply_clustering(layer):
    # Wrap only the convolutional layers; a single isinstance call with a
    # tuple covers both layer types.
    if isinstance(layer,
                  (keras.layers.Conv1D, keras.layers.Conv1DTranspose)):
      return cluster.cluster_weights(layer, **self.params)
    return layer

  model_to_cluster = keras.models.clone_model(
      model,
      clone_function=apply_clustering,
  )

  model_to_cluster.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer="adam",
      metrics=["accuracy"]
  )
  # Random data is sufficient: the test only checks that training runs and
  # that clustering bounds the number of unique weights.
  model_to_cluster.fit(
      np.random.randn(*self._batch(model.input.get_shape().as_list(), 16)),
      np.random.randn(*self._batch(model.output.get_shape().as_list(), 16)),
      steps_per_epoch=1)
  clustered_model = cluster.strip_clustering(model_to_cluster)

  def do_checks(layer, layer_name):
    # The stripped layer keeps its original name, and its kernel holds at
    # most `number_of_clusters` distinct values.
    self.assertEqual(layer.name, layer_name)
    unique_weights = np.unique(layer.weights[0].numpy().flatten())
    self.assertLessEqual(
        len(unique_weights), self.params["number_of_clusters"])

  do_checks(clustered_model.layers[2], "conv1d")
  do_checks(clustered_model.layers[3], "conv1d_transpose")
290
+
291
def testStripClusteringSequentialModelWithRegulariser(self):
  """Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
  original_model = keras.Sequential([
      layers.Dense(5, input_shape=(5,)),
      layers.Dense(5, kernel_regularizer=tf.keras.regularizers.L1(0.01)),
  ])

  def clusters_check(stripped_model):
    # First dense kernel: the number of distinct values is bounded by the
    # configured cluster count.
    flattened = stripped_model.get_weights()[0].reshape(-1,).tolist()
    self.assertLessEqual(
        len(set(flattened)), self.params["number_of_clusters"])

  self.end_to_end_testing(original_model, clusters_check)
306
+
228
307
@keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
229
308
def testEndToEndFunctional (self ):
230
309
"""Test End to End clustering - functional model."""
0 commit comments