@@ -164,6 +164,37 @@ def main(unused_argv):
164164 print ('evaluate 1x4 model' )
165165 print (keras_test_utils .eval_mnist_tflite (model_content = tflite_model ))
166166
167+ ##############################################################################
168+ # Train and convert a model with 1x16 block config, and enable post-training
169+ # dynamic range quantization during conversion.
170+ ##############################################################################
171+ pruning_params = {
172+ 'pruning_schedule' :
173+ ConstantSparsity (FLAGS .sparsity , begin_step = 0 , frequency = 100 ),
174+ # TFLite transposes the weight during conversion, so we need to specify
175+ # the block as (16, 1) in the training API.
176+ 'block_size' : (16 , 1 )
177+ }
178+
179+ model = build_layerwise_model (input_shape , ** pruning_params )
180+ model = train (model , x_train , y_train , x_test , y_test )
181+
182+ converter = tf .lite .TFLiteConverter .from_keras_model (model )
183+ converter .optimizations = {
184+ tf .lite .Optimize .DEFAULT , tf .lite .Optimize .EXPERIMENTAL_SPARSITY
185+ }
186+
187+ tflite_model = converter .convert ()
188+ # Check the model is compressed
189+ print ('Compression ratio: ' , len (tflite_model ) / len (tflite_model_dense ))
190+
191+ tflite_model_path = '/tmp/sparse_mnist_%s_1x16.tflite' % FLAGS .sparsity
192+ with open (tflite_model_path , 'wb' ) as f :
193+ f .write (tflite_model )
194+
195+ print ('evaluate 1x16 model' )
196+ print (keras_test_utils .eval_mnist_tflite (model_content = tflite_model ))
197+
167198
168199if __name__ == '__main__' :
169200 absl_app .run (main )
0 commit comments