@@ -538,23 +538,26 @@ def testClusterStackedRNNCells(self):
538538 expected_unique_weights = self .params_clustering ["number_of_clusters" ],
539539 )
540540
541+
541542class ClusterMHAIntegrationTest (tf .test .TestCase , parameterized .TestCase ):
542543 """Integration tests for clustering MHA layer."""
543544
544545 def setUp (self ):
546+ super (ClusterMHAIntegrationTest , self ).setUp ()
545547 self .x_train = np .random .uniform (size = (500 , 32 , 32 ))
546548 self .y_train = np .random .randint (low = 0 , high = 1024 , size = (500 ,))
547549
548550 self .nr_of_clusters = 16
549551 self .params_clustering = {
550- "number_of_clusters" : self .nr_of_clusters ,
551- "cluster_centroids_init" : CentroidInitialization .KMEANS_PLUS_PLUS ,
552+ "number_of_clusters" : self .nr_of_clusters ,
553+ "cluster_centroids_init" : CentroidInitialization .KMEANS_PLUS_PLUS ,
552554 }
553555
554556 def _get_model (self ):
555557 """Returns functional model with MHA layer."""
556- inp = tf .keras .layers .Input (shape = (32 ,32 ), batch_size = 100 )
557- x = tf .keras .layers .MultiHeadAttention (num_heads = 2 , key_dim = 16 )(query = inp , value = inp )
558+ inp = tf .keras .layers .Input (shape = (32 , 32 ), batch_size = 100 )
559+ x = tf .keras .layers .MultiHeadAttention (num_heads = 2 , key_dim = 16 )(
560+ query = inp , value = inp )
558561 out = tf .keras .layers .Flatten ()(x )
559562 model = tf .keras .Model (inputs = inp , outputs = out )
560563 return model
@@ -566,18 +569,70 @@ def testMHA(self):
566569 clustered_model = cluster .cluster_weights (model , ** self .params_clustering )
567570
568571 clustered_model .compile (
569- optimizer = tf .keras .optimizers .Adam (learning_rate = 1e-4 ),
570- loss = tf .keras .losses .SparseCategoricalCrossentropy (from_logits = True ),
571- metrics = [tf .keras .metrics .SparseCategoricalAccuracy (name = 'accuracy' )])
572- clustered_model .fit (self .x_train , self .y_train , epochs = 1 , batch_size = 100 , verbose = 1 )
572+ optimizer = tf .keras .optimizers .Adam (learning_rate = 1e-4 ),
573+ loss = tf .keras .losses .SparseCategoricalCrossentropy (from_logits = True ),
574+ metrics = [tf .keras .metrics .SparseCategoricalAccuracy (name = "accuracy" )])
575+ clustered_model .fit (
576+ self .x_train , self .y_train , epochs = 1 , batch_size = 100 , verbose = 1 )
573577
574578 stripped_model = cluster .strip_clustering (clustered_model )
575579
576- layerMHA = stripped_model .layers [1 ]
577- for weight in layerMHA .weights :
578- if ' kernel' in weight .name :
580+ layer_mha = stripped_model .layers [1 ]
581+ for weight in layer_mha .weights :
582+ if " kernel" in weight .name :
579583 nr_unique_weights = len (np .unique (weight .numpy ()))
580584 assert nr_unique_weights == self .nr_of_clusters
581585
586+
587+ class ClusterPerChannelIntegrationTest (tf .test .TestCase ,
588+ parameterized .TestCase ):
589+ """Integration tests for per-channel clustering of Conv2D layer."""
590+
591+ def setUp (self ):
592+ super (ClusterPerChannelIntegrationTest , self ).setUp ()
593+ self .x_train = np .random .uniform (size = (500 , 32 , 32 ))
594+ self .y_train = np .random .randint (low = 0 , high = 1024 , size = (500 ,))
595+
596+ self .nr_of_clusters = 4
597+ self .num_channels = 12
598+ self .params_clustering = {
599+ "number_of_clusters" : self .nr_of_clusters ,
600+ "cluster_centroids_init" : CentroidInitialization .KMEANS_PLUS_PLUS ,
601+ "cluster_per_channel" : True
602+ }
603+
604+ def _get_model (self ):
605+ """Returns functional model with Conv2D layer."""
606+ inp = tf .keras .layers .Input (shape = (32 , 32 ), batch_size = 100 )
607+ x = tf .keras .layers .Reshape ((32 , 32 , 1 ))(inp )
608+ x = tf .keras .layers .Conv2D (
609+ filters = self .num_channels , kernel_size = (3 , 3 ),
610+ activation = "relu" )(x )
611+ x = tf .keras .layers .MaxPool2D (2 , 2 )(x )
612+ out = tf .keras .layers .Flatten ()(x )
613+ model = tf .keras .Model (inputs = inp , outputs = out )
614+ return model
615+
616+ @keras_parameterized .run_all_keras_modes
617+ def testPerChannel (self ):
618+ model = self ._get_model ()
619+
620+ clustered_model = cluster .cluster_weights (model , ** self .params_clustering )
621+
622+ clustered_model .compile (
623+ optimizer = tf .keras .optimizers .Adam (learning_rate = 1e-4 ),
624+ loss = tf .keras .losses .SparseCategoricalCrossentropy (from_logits = True ),
625+ metrics = [tf .keras .metrics .SparseCategoricalAccuracy (name = "accuracy" )])
626+ clustered_model .fit (
627+ self .x_train , self .y_train , epochs = 1 , batch_size = 100 , verbose = 1 )
628+
629+ stripped_model = cluster .strip_clustering (clustered_model )
630+
631+ layer_conv2d = stripped_model .layers [2 ]
632+ for weight in layer_conv2d .weights :
633+ if "kernel" in weight .name :
634+ nr_unique_weights = len (np .unique (weight .numpy ()))
635+ assert nr_unique_weights == self .nr_of_clusters * self .num_channels
636+
582637if __name__ == "__main__" :
583638 test .main ()
0 commit comments