@@ -43,19 +43,57 @@ def setUp(self):
4343 }
4444
4545 self .x_train = np .array (
46- [[0.0 , 1.0 ], [2.0 , 0.0 ], [0.0 , 3.0 ], [4.0 , 1.0 ], [5.0 , 1.0 ]],
46+ [[0.0 , 1.0 , 2.0 , 3.0 , 4.0 ], [2.0 , 0.0 , 2.0 , 3.0 , 4.0 ], [0.0 , 3.0 , 2.0 , 3.0 , 4.0 ],
47+ [4.0 , 1.0 , 2.0 , 3.0 , 4.0 ], [5.0 , 1.0 , 2.0 , 3.0 , 4.0 ]],
4748 dtype = "float32" ,
4849 )
4950
5051 self .y_train = np .array (
51- [[0.0 , 1.0 ], [1.0 , 0.0 ], [1.0 , 0.0 ], [0.0 , 1.0 ], [0.0 , 1.0 ]],
52+ [[0.0 , 1.0 , 2.0 , 3.0 , 4.0 ], [1.0 , 0.0 , 2.0 , 3.0 , 4.0 ], [1.0 , 0.0 , 2.0 , 3.0 , 4.0 ],
53+ [0.0 , 1.0 , 2.0 , 3.0 , 4.0 ], [0.0 , 1.0 , 2.0 , 3.0 , 4.0 ]],
54+ dtype = "float32" ,
55+ )
56+
57+ self .x_test = np .array (
58+ [[1.0 , 2.0 , 3.0 , 4.0 , 5.0 ], [6.0 , 7.0 , 8.0 , 9.0 , 10.0 ], [1.0 , 2.0 , 3.0 , 4.0 , 5.0 ],
59+ [6.0 , 1.0 , 2.0 , 3.0 , 4.0 ], [9.0 , 1.0 , 0.0 , 3.0 , 0.0 ]],
5260 dtype = "float32" ,
5361 )
5462
5563 def dataset_generator (self ):
5664 for x , y in zip (self .x_train , self .y_train ):
5765 yield np .array ([x ]), np .array ([y ])
5866
67+ def end_to_end_testing (self , original_model , clusters_check = None ):
68+ """Test End to End clustering."""
69+
70+ clustered_model = cluster .cluster_weights (original_model , ** self .params )
71+
72+ clustered_model .compile (
73+ loss = keras .losses .categorical_crossentropy ,
74+ optimizer = "adam" ,
75+ metrics = ["accuracy" ],
76+ )
77+
78+ clustered_model .fit (x = self .dataset_generator (), steps_per_epoch = 1 )
79+ stripped_model = cluster .strip_clustering (clustered_model )
80+ if clusters_check is not None :
81+ clusters_check (stripped_model )
82+
83+ _ , tflite_file = tempfile .mkstemp (".tflite" )
84+ _ , keras_file = tempfile .mkstemp (".h5" )
85+
86+ converter = tf .lite .TFLiteConverter .from_keras_model (stripped_model )
87+ tflite_model = converter .convert ()
88+
89+ with open (tflite_file , "wb" ) as f :
90+ f .write (tflite_model )
91+
92+ self ._verify_tflite (tflite_file , self .x_test )
93+
94+ os .remove (keras_file )
95+ os .remove (tflite_file )
96+
5997 @staticmethod
6098 def _verify_tflite (tflite_file , x_test ):
6199 interpreter = tf .lite .Interpreter (model_path = tflite_file )
@@ -72,8 +110,8 @@ def _verify_tflite(tflite_file, x_test):
72110 def testValuesRemainClusteredAfterTraining (self ):
73111 """Verifies that training a clustered model does not destroy the clusters."""
74112 original_model = keras .Sequential ([
75- layers .Dense (2 , input_shape = (2 ,)),
76- layers .Dense (2 ),
113+ layers .Dense (5 , input_shape = (5 ,)),
114+ layers .Dense (5 ),
77115 ])
78116
79117 clustered_model = cluster .cluster_weights (original_model , ** self .params )
@@ -91,42 +129,95 @@ def testValuesRemainClusteredAfterTraining(self):
91129 self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
92130
93131 @keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
94- def testEndToEnd (self ):
95- """Test End to End clustering."""
132+ def testEndToEndSequential (self ):
133+ """Test End to End clustering - sequential model ."""
96134 original_model = keras .Sequential ([
97- layers .Dense (2 , input_shape = (2 ,)),
98- layers .Dense (2 ),
135+ layers .Dense (5 , input_shape = (5 ,)),
136+ layers .Dense (5 ),
99137 ])
100138
101- clustered_model = cluster .cluster_weights (original_model , ** self .params )
139+ def clusters_check (stripped_model ):
140+ # dense layer
141+ weights_as_list = stripped_model .get_weights ()[0 ].reshape (- 1 ,).tolist ()
142+ unique_weights = set (weights_as_list )
143+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
102144
103- clustered_model .compile (
104- loss = keras .losses .categorical_crossentropy ,
105- optimizer = "adam" ,
106- metrics = ["accuracy" ],
107- )
145+ self .end_to_end_testing (original_model , clusters_check )
108146
109- clustered_model .fit (x = self .dataset_generator (), steps_per_epoch = 1 )
110- stripped_model = cluster .strip_clustering (clustered_model )
147+ @keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
148+ def testEndToEndFunctional (self ):
149+ """Test End to End clustering - functional model."""
150+ inputs = keras .layers .Input (shape = (5 ,))
151+ layer1 = keras .layers .Dense (5 )(inputs )
152+ layer2 = keras .layers .Dense (5 )(layer1 )
153+ original_model = keras .Model (inputs = inputs , outputs = layer2 )
111154
112- _ , tflite_file = tempfile .mkstemp (".tflite" )
113- _ , keras_file = tempfile .mkstemp (".h5" )
155+ def clusters_check (stripped_model ):
156+ # First dense layer
157+ weights_as_list = stripped_model .get_weights ()[0 ].reshape (- 1 ,).tolist ()
158+ unique_weights = set (weights_as_list )
159+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
114160
115- if not compat .is_v1_apis ():
116- converter = tf .lite .TFLiteConverter .from_keras_model (stripped_model )
117- else :
118- tf .keras .models .save_model (stripped_model , keras_file )
119- converter = tf .lite .TFLiteConverter .from_keras_model_file (keras_file )
161+ self .end_to_end_testing (original_model , clusters_check )
120162
121- tflite_model = converter .convert ()
122- with open (tflite_file , "wb" ) as f :
123- f .write (tflite_model )
163+ @keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
164+ def testEndToEndDeepLayer (self ):
165+ """Test End to End clustering for the model with deep layer."""
166+ internal_model = tf .keras .Sequential ([tf .keras .layers .Dense (5 , input_shape = (5 ,))])
167+ original_model = keras .Sequential ([
168+ internal_model ,
169+ layers .Dense (5 ),
170+ ])
124171
125- self ._verify_tflite (tflite_file , self .x_train )
172+ def clusters_check (stripped_model ):
173+ # inner dense layer
174+ weights_as_list = stripped_model ._layers [1 ]._layers [1 ].trainable_weights [0 ].\
175+ numpy ().flatten ()
176+ unique_weights = set (weights_as_list )
177+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
126178
127- os .remove (keras_file )
128- os .remove (tflite_file )
179+ # outer dense layer
180+ weights_as_list = stripped_model ._layers [2 ].trainable_weights [0 ].\
181+ numpy ().flatten ()
182+ unique_weights = set (weights_as_list )
183+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
184+
185+ self .end_to_end_testing (original_model , clusters_check )
186+
187+ @keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
188+ def testEndToEndDeepLayer2 (self ):
189+ """Test End to End clustering for the model with 2 deep layers."""
190+ internal_model = tf .keras .Sequential ([tf .keras .layers .Dense (5 , input_shape = (5 ,))])
191+ intermediate_model = keras .Sequential ([
192+ internal_model ,
193+ layers .Dense (5 ),
194+ ])
195+ original_model = keras .Sequential ([
196+ intermediate_model ,
197+ layers .Dense (5 ),
198+ ])
129199
200+ def clusters_check (stripped_model ):
201+ # first inner dense layer
202+ weights_as_list = stripped_model ._layers [1 ]._layers [1 ].trainable_weights [0 ].\
203+ numpy ().flatten ()
204+ unique_weights = set (weights_as_list )
205+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
206+
207+ # second inner dense layer
208+ weights_as_list = stripped_model ._layers [1 ]._layers [1 ]._layers [1 ].\
209+ trainable_weights [0 ].\
210+ numpy ().flatten ()
211+ unique_weights = set (weights_as_list )
212+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
213+
214+ # outer dense layer
215+ weights_as_list = stripped_model ._layers [2 ].trainable_weights [0 ].\
216+ numpy ().flatten ()
217+ unique_weights = set (weights_as_list )
218+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
219+
220+ self .end_to_end_testing (original_model , clusters_check )
130221
131222if __name__ == "__main__" :
132223 test .main ()
0 commit comments