@@ -164,6 +164,37 @@ def main(unused_argv):
164
164
print ('evaluate 1x4 model' )
165
165
print (keras_test_utils .eval_mnist_tflite (model_content = tflite_model ))
166
166
167
+ ##############################################################################
168
+ # Train and convert a model with 1x16 block config, and enable post-training
169
+ # dynamic range quantization during conversion.
170
+ ##############################################################################
171
+ pruning_params = {
172
+ 'pruning_schedule' :
173
+ ConstantSparsity (FLAGS .sparsity , begin_step = 0 , frequency = 100 ),
174
+ # TFLite transposes the weight during conversion, so we need to specify
175
+ # the block as (16, 1) in the training API.
176
+ 'block_size' : (16 , 1 )
177
+ }
178
+
179
+ model = build_layerwise_model (input_shape , ** pruning_params )
180
+ model = train (model , x_train , y_train , x_test , y_test )
181
+
182
+ converter = tf .lite .TFLiteConverter .from_keras_model (model )
183
+ converter .optimizations = {
184
+ tf .lite .Optimize .DEFAULT , tf .lite .Optimize .EXPERIMENTAL_SPARSITY
185
+ }
186
+
187
+ tflite_model = converter .convert ()
188
+ # Check the model is compressed
189
+ print ('Compression ratio: ' , len (tflite_model ) / len (tflite_model_dense ))
190
+
191
+ tflite_model_path = '/tmp/sparse_mnist_%s_1x16.tflite' % FLAGS .sparsity
192
+ with open (tflite_model_path , 'wb' ) as f :
193
+ f .write (tflite_model )
194
+
195
+ print ('evaluate 1x16 model' )
196
+ print (keras_test_utils .eval_mnist_tflite (model_content = tflite_model ))
197
+
167
198
168
199
# Script entry point: delegate to absl's app runner, which parses FLAGS
# before invoking main().
if __name__ == '__main__':
  absl_app.run(main)
0 commit comments