@@ -43,19 +43,57 @@ def setUp(self):
43
43
}
44
44
45
45
self .x_train = np .array (
46
- [[0.0 , 1.0 ], [2.0 , 0.0 ], [0.0 , 3.0 ], [4.0 , 1.0 ], [5.0 , 1.0 ]],
46
+ [[0.0 , 1.0 , 2.0 , 3.0 , 4.0 ], [2.0 , 0.0 , 2.0 , 3.0 , 4.0 ], [0.0 , 3.0 , 2.0 , 3.0 , 4.0 ],
47
+ [4.0 , 1.0 , 2.0 , 3.0 , 4.0 ], [5.0 , 1.0 , 2.0 , 3.0 , 4.0 ]],
47
48
dtype = "float32" ,
48
49
)
49
50
50
51
self .y_train = np .array (
51
- [[0.0 , 1.0 ], [1.0 , 0.0 ], [1.0 , 0.0 ], [0.0 , 1.0 ], [0.0 , 1.0 ]],
52
+ [[0.0 , 1.0 , 2.0 , 3.0 , 4.0 ], [1.0 , 0.0 , 2.0 , 3.0 , 4.0 ], [1.0 , 0.0 , 2.0 , 3.0 , 4.0 ],
53
+ [0.0 , 1.0 , 2.0 , 3.0 , 4.0 ], [0.0 , 1.0 , 2.0 , 3.0 , 4.0 ]],
54
+ dtype = "float32" ,
55
+ )
56
+
57
+ self .x_test = np .array (
58
+ [[1.0 , 2.0 , 3.0 , 4.0 , 5.0 ], [6.0 , 7.0 , 8.0 , 9.0 , 10.0 ], [1.0 , 2.0 , 3.0 , 4.0 , 5.0 ],
59
+ [6.0 , 1.0 , 2.0 , 3.0 , 4.0 ], [9.0 , 1.0 , 0.0 , 3.0 , 0.0 ]],
52
60
dtype = "float32" ,
53
61
)
54
62
55
63
def dataset_generator (self ):
56
64
for x , y in zip (self .x_train , self .y_train ):
57
65
yield np .array ([x ]), np .array ([y ])
58
66
67
+ def end_to_end_testing (self , original_model , clusters_check = None ):
68
+ """Test End to End clustering."""
69
+
70
+ clustered_model = cluster .cluster_weights (original_model , ** self .params )
71
+
72
+ clustered_model .compile (
73
+ loss = keras .losses .categorical_crossentropy ,
74
+ optimizer = "adam" ,
75
+ metrics = ["accuracy" ],
76
+ )
77
+
78
+ clustered_model .fit (x = self .dataset_generator (), steps_per_epoch = 1 )
79
+ stripped_model = cluster .strip_clustering (clustered_model )
80
+ if clusters_check is not None :
81
+ clusters_check (stripped_model )
82
+
83
+ _ , tflite_file = tempfile .mkstemp (".tflite" )
84
+ _ , keras_file = tempfile .mkstemp (".h5" )
85
+
86
+ converter = tf .lite .TFLiteConverter .from_keras_model (stripped_model )
87
+ tflite_model = converter .convert ()
88
+
89
+ with open (tflite_file , "wb" ) as f :
90
+ f .write (tflite_model )
91
+
92
+ self ._verify_tflite (tflite_file , self .x_test )
93
+
94
+ os .remove (keras_file )
95
+ os .remove (tflite_file )
96
+
59
97
@staticmethod
60
98
def _verify_tflite (tflite_file , x_test ):
61
99
interpreter = tf .lite .Interpreter (model_path = tflite_file )
@@ -72,8 +110,8 @@ def _verify_tflite(tflite_file, x_test):
72
110
def testValuesRemainClusteredAfterTraining (self ):
73
111
"""Verifies that training a clustered model does not destroy the clusters."""
74
112
original_model = keras .Sequential ([
75
- layers .Dense (2 , input_shape = (2 ,)),
76
- layers .Dense (2 ),
113
+ layers .Dense (5 , input_shape = (5 ,)),
114
+ layers .Dense (5 ),
77
115
])
78
116
79
117
clustered_model = cluster .cluster_weights (original_model , ** self .params )
@@ -91,42 +129,95 @@ def testValuesRemainClusteredAfterTraining(self):
91
129
self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
92
130
93
131
@keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
94
- def testEndToEnd (self ):
95
- """Test End to End clustering."""
132
+ def testEndToEndSequential (self ):
133
+ """Test End to End clustering - sequential model ."""
96
134
original_model = keras .Sequential ([
97
- layers .Dense (2 , input_shape = (2 ,)),
98
- layers .Dense (2 ),
135
+ layers .Dense (5 , input_shape = (5 ,)),
136
+ layers .Dense (5 ),
99
137
])
100
138
101
- clustered_model = cluster .cluster_weights (original_model , ** self .params )
139
+ def clusters_check (stripped_model ):
140
+ # dense layer
141
+ weights_as_list = stripped_model .get_weights ()[0 ].reshape (- 1 ,).tolist ()
142
+ unique_weights = set (weights_as_list )
143
+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
102
144
103
- clustered_model .compile (
104
- loss = keras .losses .categorical_crossentropy ,
105
- optimizer = "adam" ,
106
- metrics = ["accuracy" ],
107
- )
145
+ self .end_to_end_testing (original_model , clusters_check )
108
146
109
- clustered_model .fit (x = self .dataset_generator (), steps_per_epoch = 1 )
110
- stripped_model = cluster .strip_clustering (clustered_model )
147
+ @keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
148
+ def testEndToEndFunctional (self ):
149
+ """Test End to End clustering - functional model."""
150
+ inputs = keras .layers .Input (shape = (5 ,))
151
+ layer1 = keras .layers .Dense (5 )(inputs )
152
+ layer2 = keras .layers .Dense (5 )(layer1 )
153
+ original_model = keras .Model (inputs = inputs , outputs = layer2 )
111
154
112
- _ , tflite_file = tempfile .mkstemp (".tflite" )
113
- _ , keras_file = tempfile .mkstemp (".h5" )
155
+ def clusters_check (stripped_model ):
156
+ # First dense layer
157
+ weights_as_list = stripped_model .get_weights ()[0 ].reshape (- 1 ,).tolist ()
158
+ unique_weights = set (weights_as_list )
159
+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
114
160
115
- if not compat .is_v1_apis ():
116
- converter = tf .lite .TFLiteConverter .from_keras_model (stripped_model )
117
- else :
118
- tf .keras .models .save_model (stripped_model , keras_file )
119
- converter = tf .lite .TFLiteConverter .from_keras_model_file (keras_file )
161
+ self .end_to_end_testing (original_model , clusters_check )
120
162
121
- tflite_model = converter .convert ()
122
- with open (tflite_file , "wb" ) as f :
123
- f .write (tflite_model )
163
+ @keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
164
+ def testEndToEndDeepLayer (self ):
165
+ """Test End to End clustering for the model with deep layer."""
166
+ internal_model = tf .keras .Sequential ([tf .keras .layers .Dense (5 , input_shape = (5 ,))])
167
+ original_model = keras .Sequential ([
168
+ internal_model ,
169
+ layers .Dense (5 ),
170
+ ])
124
171
125
- self ._verify_tflite (tflite_file , self .x_train )
172
+ def clusters_check (stripped_model ):
173
+ # inner dense layer
174
+ weights_as_list = stripped_model ._layers [1 ]._layers [1 ].trainable_weights [0 ].\
175
+ numpy ().flatten ()
176
+ unique_weights = set (weights_as_list )
177
+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
126
178
127
- os .remove (keras_file )
128
- os .remove (tflite_file )
179
+ # outer dense layer
180
+ weights_as_list = stripped_model ._layers [2 ].trainable_weights [0 ].\
181
+ numpy ().flatten ()
182
+ unique_weights = set (weights_as_list )
183
+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
184
+
185
+ self .end_to_end_testing (original_model , clusters_check )
186
+
187
+ @keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
188
+ def testEndToEndDeepLayer2 (self ):
189
+ """Test End to End clustering for the model with 2 deep layers."""
190
+ internal_model = tf .keras .Sequential ([tf .keras .layers .Dense (5 , input_shape = (5 ,))])
191
+ intermediate_model = keras .Sequential ([
192
+ internal_model ,
193
+ layers .Dense (5 ),
194
+ ])
195
+ original_model = keras .Sequential ([
196
+ intermediate_model ,
197
+ layers .Dense (5 ),
198
+ ])
129
199
200
+ def clusters_check (stripped_model ):
201
+ # first inner dense layer
202
+ weights_as_list = stripped_model ._layers [1 ]._layers [1 ].trainable_weights [0 ].\
203
+ numpy ().flatten ()
204
+ unique_weights = set (weights_as_list )
205
+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
206
+
207
+ # second inner dense layer
208
+ weights_as_list = stripped_model ._layers [1 ]._layers [1 ]._layers [1 ].\
209
+ trainable_weights [0 ].\
210
+ numpy ().flatten ()
211
+ unique_weights = set (weights_as_list )
212
+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
213
+
214
+ # outer dense layer
215
+ weights_as_list = stripped_model ._layers [2 ].trainable_weights [0 ].\
216
+ numpy ().flatten ()
217
+ unique_weights = set (weights_as_list )
218
+ self .assertLessEqual (len (unique_weights ), self .params ["number_of_clusters" ])
219
+
220
+ self .end_to_end_testing (original_model , clusters_check )
130
221
131
222
if __name__ == "__main__" :
132
223
test .main ()
0 commit comments