Skip to content

Commit c1be7be

Browse files
committed
Enable Prune Aware QAT through QuantizationScheme.
Change-Id: I2cd50868696b277a94cbcc50b167c8eb1a6612d7
1 parent e39363d commit c1be7be

File tree

8 files changed

+106
-9
lines changed

8 files changed

+106
-9
lines changed

tensorflow_model_optimization/python/core/api/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ py_library(
1212
"quantization/keras/__init__.py",
1313
"quantization/keras/default_8bit/__init__.py",
1414
"quantization/keras/quantizers/__init__.py",
15+
"quantization/keras/experimental_scheme/__init__.py",
1516
"sparsity/__init__.py",
1617
"sparsity/keras/__init__.py",
1718
],

tensorflow_model_optimization/python/core/api/quantization/keras/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# submodules
1919
from tensorflow_model_optimization.python.core.api.quantization.keras import quantizers
2020
from tensorflow_model_optimization.python.core.api.quantization.keras import default_8bit
21+
from tensorflow_model_optimization.python.core.api.quantization.keras import experimental_scheme
2122

2223
# quantize all layers with default quantization implementation.
2324
from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_model
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Module containing experimental Quantization schemes."""
16+
17+
from tensorflow_model_optimization.python.core.quantization.keras.prune_preserve.default_8bit_prune_preserve_quantize_scheme import (
18+
Default8BitPrunePreserveQuantizeScheme, )

tensorflow_model_optimization/python/core/quantization/keras/prune_preserve/BUILD

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ py_library(
1010
"__init__.py",
1111
],
1212
srcs_version = "PY3",
13-
deps = [],
13+
deps = [
14+
":default_8bit_prune_preserve_quantize_scheme",
15+
],
1416
)
1517

1618
py_library(
@@ -22,6 +24,7 @@ py_library(
2224
deps = [
2325
# tensorflow dep1,
2426
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantizers",
27+
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
2528
],
2629
)
2730

@@ -37,4 +40,17 @@ py_test(
3740
"//tensorflow_model_optimization/python/core/sparsity/keras:prune_registry",
3841
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
3942
]
43+
)
44+
45+
py_library(
46+
name = "default_8bit_prune_preserve_quantize_scheme",
47+
srcs = [
48+
"default_8bit_prune_preserve_quantize_scheme.py",
49+
],
50+
srcs_version = "PY3",
51+
visibility = ["//visibility:public"],
52+
deps = [
53+
":prune_preserve_quantize_registry",
54+
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_scheme",
55+
],
4056
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Default 8 bit Prune Preserve Quantization scheme which specifies how quantization should be applied."""
16+
17+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
18+
default_8bit_quantize_scheme, )
19+
from tensorflow_model_optimization.python.core.quantization.keras.prune_preserve import (
20+
prune_preserve_quantize_registry, )
21+
22+
23+
class Default8BitPrunePreserveQuantizeScheme(
    default_8bit_quantize_scheme.Default8BitQuantizeScheme):
  """Default 8 bit quantization scheme backed by the prune preserve registry.

  The layout transformer is taken unchanged from the parent
  `Default8BitQuantizeScheme`; only the quantize registry is replaced, by
  `Default8bitPrunePreserveQuantizeRegistry`.
  """

  def get_layout_transformer(self):
    """Returns the layout transformer supplied by the parent scheme."""
    parent_scheme = super(Default8BitPrunePreserveQuantizeScheme, self)
    return parent_scheme.get_layout_transformer()

  def get_quantize_registry(self):
    """Returns a newly constructed `Default8bitPrunePreserveQuantizeRegistry`."""
    registry_cls = (
        prune_preserve_quantize_registry.Default8bitPrunePreserveQuantizeRegistry
    )
    return registry_cls()
31+

tensorflow_model_optimization/python/core/quantization/keras/prune_preserve/prune_preserve_quantize_registry.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
2020
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
2121
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
22-
default_8bit_quantizers)
22+
default_8bit_quantizers, )
23+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
24+
default_8bit_quantize_registry, )
2325

2426
layers = tf.keras.layers
2527

@@ -179,16 +181,40 @@ def apply_sparsity_preserve_quantize_config(self, layer, quantize_config):
179181
quantize_config.weight_quantizer = self._config_quantizer_map[
180182
quantize_config.__class__.__name__]
181183
else:
182-
raise ValueError('Configuration ' +
183-
str(quantize_config.__class__.__name__) +
184-
' is not supported for Layer ' +
185-
str(layer.__class__) + '.')
184+
raise ValueError("Configuration " +
185+
str(quantize_config.__class__.__name__) +
186+
" is not supported for Layer " +
187+
str(layer.__class__) + ".")
186188
else:
187-
raise ValueError('Layer ' + str(layer.__class__) + ' is not supported.')
189+
raise ValueError("Layer " + str(layer.__class__) + " is not supported.")
188190

189191
return quantize_config
190192

191193

194+
class Default8bitPrunePreserveQuantizeRegistry(PrunePreserveQuantizeRegistry):
  """Default 8 bit registry whose configs use sparsity preserving quantizers.

  Quantize configs are obtained from the default 8 bit `QuantizeRegistry` and
  then passed through `apply_sparsity_preserve_quantize_config` of the parent
  `PrunePreserveQuantizeRegistry`.
  """

  def __init__(self):
    """Initializes the registry through the parent constructor."""
    super(Default8bitPrunePreserveQuantizeRegistry, self).__init__()

  def get_quantize_config(self, layer):
    """Returns the default 8 bit quantize config adapted to preserve sparsity.

    Args:
      layer: input layer to return quantize config for.

    Returns:
      The quantize config produced by the default 8 bit registry for `layer`,
      after the parent registry has applied its sparsity preserve
      weight_quantizer to it.

    Raises:
      ValueError: propagated from `apply_sparsity_preserve_quantize_config`
        when the layer or its quantize config is not supported.
    """
    # A fresh default registry is built on every call, as in the original.
    default_registry = default_8bit_quantize_registry.QuantizeRegistry()
    default_config = default_registry.get_quantize_config(layer)

    # Explicitly dispatch to the parent implementation rather than `self`.
    parent_registry = super(Default8bitPrunePreserveQuantizeRegistry, self)
    return parent_registry.apply_sparsity_preserve_quantize_config(
        layer, default_config)
216+
217+
192218
class PrunePerserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
193219
"""Quantize weights while preserving sparsity."""
194220
def __init__(self, num_bits, per_axis, symmetric, narrow_range):

tensorflow_model_optimization/python/examples/quantization_with_sparsity/keras/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ py_binary(
1212
# tensorflow dep1,
1313
# python/keras tensorflow dep2,
1414
"//tensorflow_model_optimization/python/core/quantization/keras:quantize",
15+
"//tensorflow_model_optimization/python/core/quantization/keras/prune_preserve:default_8bit_prune_preserve_quantize_scheme",
1516
],
1617
)

tensorflow_model_optimization/python/examples/quantization_with_sparsity/keras/mnist_cnn.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
2727
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
2828
from tensorflow_model_optimization.python.core.quantization.keras import quantize
29+
from tensorflow_model_optimization.python.core.quantization.keras.prune_preserve import (
30+
default_8bit_prune_preserve_quantize_scheme, )
2931

3032
layers = tf.keras.layers
3133

@@ -118,8 +120,9 @@ def prune_preserve_quantize_model(pruned_model, train_images, train_labels):
118120
pruned_model = prune.strip_pruning(pruned_model)
119121
# Prune preserve QAT model
120122
quant_aware_annotate_model = quantize.quantize_annotate_model(pruned_model)
121-
quant_aware_model = quantize.quantize_apply(quant_aware_annotate_model,
122-
prune_preserve=True)
123+
quant_aware_model = quantize.quantize_apply(
124+
quant_aware_annotate_model,
125+
scheme=default_8bit_prune_preserve_quantize_scheme.Default8BitPrunePreserveQuantizeScheme())
123126
quant_aware_model.summary()
124127

125128
fit_kwargs = {

0 commit comments

Comments
 (0)