Skip to content

Commit 392b0b3

Browse files
liyunlu0618tensorflower-gardener
authored andcommitted
Add an example of combining sparsity & dynamic range quantization.
PiperOrigin-RevId: 361198975
1 parent c35fc4c commit 392b0b3

File tree

1 file changed

+31
-0
lines changed
  • tensorflow_model_optimization/python/examples/sparsity/keras/mnist

1 file changed

+31
-0
lines changed

tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_e2e.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,37 @@ def main(unused_argv):
164164
print('evaluate 1x4 model')
165165
print(keras_test_utils.eval_mnist_tflite(model_content=tflite_model))
166166

167+
##############################################################################
168+
# Train and convert a model with 1x16 block config, and enable post-training
169+
# dynamic range quantization during conversion.
170+
##############################################################################
171+
pruning_params = {
172+
'pruning_schedule':
173+
ConstantSparsity(FLAGS.sparsity, begin_step=0, frequency=100),
174+
# TFLite transposes the weight during conversion, so we need to specify
175+
# the block as (16, 1) in the training API.
176+
'block_size': (16, 1)
177+
}
178+
179+
model = build_layerwise_model(input_shape, **pruning_params)
180+
model = train(model, x_train, y_train, x_test, y_test)
181+
182+
converter = tf.lite.TFLiteConverter.from_keras_model(model)
183+
converter.optimizations = {
184+
tf.lite.Optimize.DEFAULT, tf.lite.Optimize.EXPERIMENTAL_SPARSITY
185+
}
186+
187+
tflite_model = converter.convert()
188+
# Check the model is compressed
189+
print('Compression ratio: ', len(tflite_model) / len(tflite_model_dense))
190+
191+
tflite_model_path = '/tmp/sparse_mnist_%s_1x16.tflite' % FLAGS.sparsity
192+
with open(tflite_model_path, 'wb') as f:
193+
f.write(tflite_model)
194+
195+
print('evaluate 1x16 model')
196+
print(keras_test_utils.eval_mnist_tflite(model_content=tflite_model))
197+
167198

168199
if __name__ == '__main__':
169200
absl_app.run(main)

0 commit comments

Comments
 (0)