|
| 1 | +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""Tool to quickly prune a Keras model for evaluation purpose. |
| 16 | +
|
| 17 | +Prunes the model with the given sparsity parameters, without retraining. Will |
| 18 | +output a converted TFLite model for both pruned and unpruned versions. |
| 19 | +
|
| 20 | +This tool is intended to produce sparsified models for evaluating the |
| 21 | +performance benefits (model size, inference time, …) of pruning. Since the |
| 22 | +sparsity is applied in one shot, without retraining, the accuracy of the |
| 23 | +resulting model will be severely degraded. |
| 24 | +""" |
| 25 | + |
| 26 | +from __future__ import print_function |
| 27 | + |
| 28 | +import os |
| 29 | +import tempfile |
| 30 | +import textwrap |
| 31 | +import zipfile |
| 32 | + |
| 33 | +from absl import app |
| 34 | +from absl import flags |
| 35 | +import tensorflow as tf |
| 36 | + |
| 37 | +from tensorflow_model_optimization.python.core.sparsity.keras import prune |
| 38 | +from tensorflow_model_optimization.python.core.sparsity.keras.tools import sparsity_tooling |
| 39 | + |
| 40 | + |
# Command-line flags. --model and --output_dir default to None; they are marked
# as required in the `__main__` guard at the bottom of this file.
_MODEL_PATH = flags.DEFINE_string('model', None, 'Keras model file to prune')
_OUTPUT_DIR = flags.DEFINE_string('output_dir', None, 'Output directory')
_SPARSITY = flags.DEFINE_float(
    'sparsity',
    0.8,
    'Target sparsity level, as float in [0,1] interval',
    lower_bound=0,
    upper_bound=1)
# Declared as a string flag; the value is turned into a pair of ints by
# `_parse_block_size_flag`.
_BLOCK_SIZE = flags.DEFINE_string(
    'block_size', '1,1',
    'Comma-separated dimensions (height,width) of the block sparsity pattern.'
)
| 53 | + |
| 54 | + |
| 55 | +def _parse_block_size_flag(value): |
| 56 | + height_str, weight_str = value.split(',') |
| 57 | + return int(height_str), int(weight_str) |
| 58 | + |
| 59 | + |
@flags.validator(_BLOCK_SIZE.name)
def _check_block_size(flag_value):
  """Validates that the block size flag can be parsed.

  Args:
    flag_value: Raw value of the block size flag.

  Returns:
    True when `flag_value` is accepted by `_parse_block_size_flag`.

  Raises:
    flags.ValidationError: If `_parse_block_size_flag` rejects `flag_value`.
  """
  try:
    _parse_block_size_flag(flag_value)
  # Only translate parsing failures (ValueError for malformed text,
  # AttributeError for a non-string value). A bare `except:` would also
  # swallow KeyboardInterrupt and SystemExit.
  except (ValueError, AttributeError) as err:
    raise flags.ValidationError(
        'Invalid block size value "%s".' % flag_value) from err
  return True
| 67 | + |
| 68 | + |
def convert_to_tflite(keras_model, output_path):
  """Converts the given Keras model to TFLite and writes it to a file.

  Args:
    keras_model: Keras model handed to `TFLiteConverter.from_keras_model`.
    output_path: Path of the file that receives the converted model. The file
      is opened in 'wb' mode, so existing content is overwritten.
  """
  converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
  converter.optimizations = {tf.lite.Optimize.EXPERIMENTAL_SPARSITY}

  # Convert before opening the destination, so that a failed conversion does
  # not leave an empty file behind at `output_path`.
  tflite_model = converter.convert()

  with open(output_path, 'wb') as out:
    out.write(tflite_model)
| 76 | + |
| 77 | + |
def get_gzipped_size(model_path):
  """Returns the size in bytes of a DEFLATE zip archive holding `model_path`.

  The archive is written to an anonymous temporary file that is discarded as
  soon as its size has been read.

  Args:
    model_path: Path of the file to compress.

  Returns:
    Size of the resulting zip archive, in bytes.
  """
  with tempfile.TemporaryFile(suffix='.zip') as archive_file:
    archive = zipfile.ZipFile(
        archive_file, 'w', compression=zipfile.ZIP_DEFLATED)
    try:
      archive.write(model_path)
    finally:
      # Closing the archive writes the zip central directory.
      archive.close()

    # Move to the end of the file before asking the OS for its size.
    archive_file.seek(0, os.SEEK_END)
    archive_stats = os.fstat(archive_file.fileno())
    return archive_stats.st_size
| 87 | + |
| 88 | + |
def pruned_model_filename(sparsity, block_size):
  """Builds the pruned model's file name from its sparsity parameters.

  Args:
    sparsity: Sparsity level, rendered with two decimals.
    block_size: Tuple of two ints, rendered as '<first>x<second>'.

  Returns:
    A file name such as 'pruned_model_sparsity_0.80_block_1x1.tflite'.
  """
  block_label = '%dx%d' % block_size
  return 'pruned_model_sparsity_%.2f_block_%s.tflite' % (sparsity, block_label)
| 93 | + |
| 94 | + |
def run(input_model_path, output_dir, target_sparsity, block_size):
  """Writes unpruned and pruned TFLite models, then prints their zipped sizes.

  Args:
    input_model_path: Path given to `tf.keras.models.load_model`.
    output_dir: Directory receiving both `.tflite` files; created if missing.
    target_sparsity: Sparsity level forwarded to `prune_for_benchmark`.
    block_size: Pair of ints forwarded to `prune_for_benchmark` and embedded in
      the pruned model's file name.
  """
  warning_text = textwrap.dedent("""\
      Warning: The sparse models produced by this tool have poor accuracy. They
      are not intended to be served in production, but to be used for
      performance benchmarking.""")
  print(warning_text)

  keras_model = tf.keras.models.load_model(input_model_path)

  os.makedirs(output_dir, exist_ok=True)
  baseline_path = os.path.join(output_dir, 'unpruned_model.tflite')
  sparse_filename = pruned_model_filename(target_sparsity, block_size)
  sparse_path = os.path.join(output_dir, sparse_filename)

  # Baseline: the model exactly as loaded, converted without any pruning.
  convert_to_tflite(keras_model, baseline_path)

  # Pruned variant: prune, drop the pruning wrappers, then convert.
  wrapped_model = sparsity_tooling.prune_for_benchmark(
      keras_model=keras_model,
      target_sparsity=target_sparsity,
      block_size=block_size)
  sparse_model = prune.strip_pruning(wrapped_model)
  convert_to_tflite(sparse_model, sparse_path)

  # Report the compressed size of both TFLite files and their relative change.
  baseline_bytes = get_gzipped_size(baseline_path)
  sparse_bytes = get_gzipped_size(sparse_path)
  mebibyte = 2.**20
  print('Size of gzipped TFLite models:')
  size_lines = (
      (' * Unpruned : %.2fMiB', baseline_bytes),
      (' * Pruned : %.2fMiB', sparse_bytes),
  )
  for line_format, size_in_bytes in size_lines:
    print(line_format % (size_in_bytes / mebibyte))
  print(' diff : %d%%' %
        (100. * (sparse_bytes - baseline_bytes) / baseline_bytes))
| 131 | + |
| 132 | + |
def main(argv):
  """Reads the parsed flags and runs the pruning tool.

  Args:
    argv: Command-line arguments; anything beyond the first entry is rejected.

  Raises:
    app.UsageError: If `argv` holds more than one entry.
  """
  unexpected_args = argv[1:]
  if unexpected_args:
    raise app.UsageError('Too many command-line arguments.')

  run(
      input_model_path=_MODEL_PATH.value,
      output_dir=_OUTPUT_DIR.value,
      target_sparsity=_SPARSITY.value,
      block_size=_parse_block_size_flag(_BLOCK_SIZE.value))
| 139 | + |
| 140 | + |
if __name__ == '__main__':
  # Both flags default to None, so insist that they are given explicitly.
  for required_flag in (_MODEL_PATH, _OUTPUT_DIR):
    flags.mark_flag_as_required(required_flag.name)

  app.run(main)
0 commit comments