Skip to content

Commit f676a3b

Browse files
Add N-bit QAT features (fake-quantization).
PiperOrigin-RevId: 395863467
1 parent 0f80b18 commit f676a3b

19 files changed

+2983
-7
lines changed

tensorflow_model_optimization/python/core/api/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ py_strict_library(
1717
"quantization/keras/collaborative_optimizations/__init__.py",
1818
"quantization/keras/default_8bit/__init__.py",
1919
"quantization/keras/default_8bit/default_8bit_transforms/__init__.py",
20+
"quantization/keras/experimental/__init__.py",
21+
"quantization/keras/experimental/default_n_bit/__init__.py",
22+
"quantization/keras/experimental/default_n_bit/default_n_bit_transforms/__init__.py",
2023
"quantization/keras/graph_transformations/__init__.py",
2124
"quantization/keras/graph_transformations/model_transformer/__init__.py",
2225
"quantization/keras/graph_transformations/transforms/__init__.py",
@@ -46,6 +49,10 @@ py_strict_library(
4649
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
4750
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_scheme",
4851
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_transforms",
52+
"//tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit:default_n_bit_quantize_layout_transform",
53+
"//tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit:default_n_bit_quantize_registry",
54+
"//tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit:default_n_bit_quantize_scheme",
55+
"//tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit:default_n_bit_transforms",
4956
"//tensorflow_model_optimization/python/core/quantization/keras/graph_transformations:model_transformer",
5057
"//tensorflow_model_optimization/python/core/quantization/keras/graph_transformations:transforms",
5158
"//tensorflow_model_optimization/python/core/sparsity/keras:prunable_layer",

tensorflow_model_optimization/python/core/api/quantization/keras/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from tensorflow_model_optimization.python.core.api.quantization.keras import default_8bit
2121
from tensorflow_model_optimization.python.core.api.quantization.keras import graph_transformations
2222
from tensorflow_model_optimization.python.core.api.quantization.keras import collaborative_optimizations
23+
from tensorflow_model_optimization.python.core.api.quantization.keras import experimental
2324

2425
# quantize all layers with default quantization implementation.
2526
from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_model
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Module containing experimental quantization features."""
16+
# pylint: disable=g-bad-import-order
17+
18+
# submodules
19+
from tensorflow_model_optimization.python.core.api.quantization.keras.experimental import default_n_bit
20+
21+
# pylint: enable=g-bad-import-order
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Module containing N-bit default quantization scheme."""
16+
# pylint: disable=g-bad-import-order
17+
18+
# submodules
19+
from tensorflow_model_optimization.python.core.api.quantization.keras.experimental.default_n_bit import default_n_bit_transforms
20+
21+
# The N-bit default quantization scheme classes.
22+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_quantize_scheme import DefaultNBitQuantizeScheme
23+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_quantize_layout_transform import DefaultNBitQuantizeLayoutTransform
24+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_quantize_registry import DefaultNBitQuantizeRegistry
25+
26+
# pylint: enable=g-bad-import-order
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Module containing N-bit default transforms."""
16+
17+
# The N-bit default transform classes.
18+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import ConcatTransform
19+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import ConcatTransform3Inputs
20+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import ConcatTransform4Inputs
21+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import ConcatTransform5Inputs
22+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import ConcatTransform6Inputs
23+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import Conv2DBatchNormActivationQuantize
24+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import Conv2DBatchNormQuantize
25+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import Conv2DBatchNormReLUQuantize
26+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import Conv2DReshapeBatchNormActivationQuantize
27+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import Conv2DReshapeBatchNormQuantize
28+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import Conv2DReshapeBatchNormReLUQuantize
29+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import InputLayerQuantize
30+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import LayerReluActivationQuantize
31+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import LayerReLUQuantize
32+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import SeparableConv1DQuantize
33+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit.default_n_bit_transforms import SeparableConvQuantize

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ py_strict_library(
1919
"//tensorflow_model_optimization/python/core/quantization/keras/graph_transformations", # buildcleaner: keep
2020
"//tensorflow_model_optimization/python/core/quantization/keras/layers", # buildcleaner: keep
2121
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit", # buildcleaner: keep
22+
"//tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit", # buildcleaner: keep
2223
"//tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations", # buildcleaner: keep
2324
],
2425
)
@@ -28,6 +29,7 @@ py_strict_library(
2829
srcs = ["quant_ops.py"],
2930
srcs_version = "PY3",
3031
deps = [
32+
# absl/logging dep1,
3133
# tensorflow dep1,
3234
# python:training tensorflow dep2,
3335
"//tensorflow_model_optimization/python/core/keras:compat",
@@ -252,6 +254,7 @@ py_strict_library(
252254
"//tensorflow_model_optimization/python/core/keras:metrics",
253255
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
254256
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_scheme",
257+
"//tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit:default_n_bit_quantize_registry",
255258
],
256259
)
257260

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library", "py_strict_test")
2+
3+
package(default_visibility = [
4+
"//tensorflow_model_optimization:__subpackages__",
5+
])
6+
7+
licenses(["notice"])
8+
9+
py_strict_library(
10+
name = "default_n_bit",
11+
srcs = [
12+
"__init__.py",
13+
],
14+
srcs_version = "PY3",
15+
deps = [],
16+
)
17+
18+
py_strict_library(
19+
name = "default_n_bit_quantizers",
20+
srcs = [
21+
"default_n_bit_quantizers.py",
22+
],
23+
srcs_version = "PY3",
24+
deps = [
25+
# tensorflow dep1,
26+
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
27+
],
28+
)
29+
30+
py_test(
31+
name = "default_n_bit_quantizers_test",
32+
srcs = [
33+
"default_n_bit_quantizers_test.py",
34+
],
35+
python_version = "PY3",
36+
deps = [
37+
":default_n_bit_quantizers",
38+
# absl/testing:parameterized dep1,
39+
# tensorflow dep1,
40+
],
41+
)
42+
43+
py_strict_library(
44+
name = "default_n_bit_quantize_configs",
45+
srcs = [
46+
"default_n_bit_quantize_configs.py",
47+
],
48+
srcs_version = "PY3",
49+
deps = [
50+
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_config",
51+
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
52+
],
53+
)
54+
55+
py_strict_library(
56+
name = "default_n_bit_quantize_registry",
57+
srcs = [
58+
"default_n_bit_quantize_registry.py",
59+
],
60+
srcs_version = "PY3",
61+
deps = [
62+
# tensorflow dep1,
63+
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_config",
64+
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_registry",
65+
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
66+
"//tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit:default_n_bit_quantize_configs",
67+
"//tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit:default_n_bit_quantizers",
68+
],
69+
)
70+
71+
py_test(
72+
name = "default_n_bit_quantize_registry_test",
73+
srcs = [
74+
"default_n_bit_quantize_registry_test.py",
75+
],
76+
python_version = "PY3",
77+
deps = [
78+
":default_n_bit_quantize_registry",
79+
# absl/testing:parameterized dep1,
80+
# numpy dep1,
81+
# tensorflow dep1,
82+
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
83+
],
84+
)
85+
86+
py_library(
87+
name = "default_n_bit_transforms",
88+
srcs = [
89+
"default_n_bit_transforms.py",
90+
],
91+
srcs_version = "PY3",
92+
visibility = ["//visibility:public"],
93+
deps = [
94+
# numpy dep1,
95+
# tensorflow dep1,
96+
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_aware_activation",
97+
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_layer",
98+
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
99+
"//tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit:default_n_bit_quantize_configs",
100+
"//tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit:default_n_bit_quantize_registry",
101+
"//tensorflow_model_optimization/python/core/quantization/keras/graph_transformations:transforms",
102+
],
103+
)
104+
105+
py_strict_library(
106+
name = "default_n_bit_quantize_layout_transform",
107+
srcs = [
108+
"default_n_bit_quantize_layout_transform.py",
109+
],
110+
srcs_version = "PY3",
111+
deps = [
112+
":default_n_bit_transforms",
113+
# tensorflow dep1,
114+
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_layout_transform",
115+
"//tensorflow_model_optimization/python/core/quantization/keras/graph_transformations:model_transformer",
116+
],
117+
)
118+
119+
py_strict_test(
120+
name = "default_n_bit_transforms_test",
121+
size = "large",
122+
srcs = [
123+
"default_n_bit_transforms_test.py",
124+
],
125+
python_version = "PY3",
126+
deps = [
127+
":default_n_bit_quantize_configs",
128+
":default_n_bit_transforms",
129+
# absl/testing:parameterized dep1,
130+
# numpy dep1,
131+
# tensorflow dep1,
132+
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_aware_activation",
133+
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_layer",
134+
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
135+
"//tensorflow_model_optimization/python/core/quantization/keras/graph_transformations:model_transformer",
136+
"//tensorflow_model_optimization/python/core/quantization/keras/layers:conv_batchnorm_test_utils",
137+
],
138+
)
139+
140+
py_strict_library(
141+
name = "default_n_bit_quantize_scheme",
142+
srcs = [
143+
"default_n_bit_quantize_scheme.py",
144+
],
145+
srcs_version = "PY3",
146+
visibility = ["//visibility:public"],
147+
deps = [
148+
":default_n_bit_quantize_layout_transform",
149+
":default_n_bit_quantize_registry",
150+
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_scheme",
151+
],
152+
)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
This directory is based on default_8bit; it allows you to manually
2+
set the number of bits for weights and activations in QAT.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Default N-Bit QuantizeConfigs."""
16+
17+
from typing import Any, Dict
18+
from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
19+
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
20+
21+
22+
class DefaultNBitOutputQuantizeConfig(quantize_config.QuantizeConfig):
  """QuantizeConfig which only quantizes the output from a layer.

  Weights and in-layer activations are left untouched; only the layer's
  output tensor is fake-quantized, using a moving-average range tracker
  at the configured activation bit width.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    # Bit widths are stored so they round-trip through get_config().
    self._num_bits_weight = num_bits_weight
    self._num_bits_activation = num_bits_activation

  def get_weights_and_quantizers(self, layer):
    # No weights are quantized by this config.
    return []

  def get_activations_and_quantizers(self, layer):
    # No in-layer activations are quantized by this config.
    return []

  def set_quantize_weights(self, layer, quantize_weights):
    pass

  def set_quantize_activations(self, layer, quantize_activations):
    pass

  def get_output_quantizers(self, layer):
    # Asymmetric, per-tensor quantizer applied to the layer output.
    output_quantizer = quantizers.MovingAverageQuantizer(
        num_bits=self._num_bits_activation,
        per_axis=False,
        symmetric=False,
        narrow_range=False)
    return [output_quantizer]

  def get_config(self) -> Dict[str, Any]:
    return dict(
        num_bits_weight=self._num_bits_weight,
        num_bits_activation=self._num_bits_activation,
    )
51+
52+
53+
class NoOpQuantizeConfig(quantize_config.QuantizeConfig):
  """QuantizeConfig which does not quantize any part of the layer.

  Every hook of the QuantizeConfig interface is a no-op: no weights,
  activations, or outputs are quantized.
  """

  def get_weights_and_quantizers(self, layer):
    return []

  def get_activations_and_quantizers(self, layer):
    return []

  def set_quantize_weights(self, layer, quantize_weights):
    pass

  def set_quantize_activations(self, layer, quantize_activations):
    pass

  def get_output_quantizers(self, layer):
    return []

  def get_config(self):
    return {}

0 commit comments

Comments
 (0)