Skip to content

Commit be98db9

Browse files
committed
[Collaborative_optimizations] Cluster-preserving quantization aware training
1 parent c35fc4c commit be98db9

File tree

12 files changed

+1114
-2
lines changed

12 files changed

+1114
-2
lines changed

tensorflow_model_optimization/python/core/api/experimental/combine/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -13,5 +13,12 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
"""Module containing collaborative optimization code."""
16+
1617
from tensorflow_model_optimization.python.core.quantization.keras.collaborative_optimizations.prune_preserve.default_8bit_prune_preserve_quantize_scheme import (
1718
Default8BitPrunePreserveQuantizeScheme,)
19+
20+
from tensorflow_model_optimization.python.core.quantization.keras.collaborative_optimizations.cluster_preserve.default_8bit_cluster_preserve_quantize_scheme import (
21+
Default8BitClusterPreserveQuantizeScheme,)
22+
23+
from tensorflow_model_optimization.python.core.quantization.keras.collaborative_optimizations.cluster_preserve.cluster_utils import (
24+
strip_clustering_cqat,)

tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/BUILD

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,8 @@ py_library(
66
name = "collaborative_optimizations",
77
srcs = ["__init__.py"],
88
srcs_version = "PY3",
9-
deps = ["//tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/prune_preserve"],
9+
deps = [
10+
"//tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/prune_preserve",
11+
"//tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve",
12+
],
1013
)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package(default_visibility = [
2+
"//tensorflow_model_optimization:__subpackages__",
3+
])
4+
5+
licenses(["notice"]) # Apache 2.0
6+
7+
py_library(
8+
name = "cluster_preserve",
9+
srcs = [
10+
"__init__.py",
11+
],
12+
srcs_version = "PY3",
13+
deps = [
14+
":default_8bit_cluster_preserve_quantize_scheme",
15+
],
16+
)
17+
18+
py_library(
19+
name = "cluster_utils",
20+
srcs = [
21+
"cluster_utils.py",
22+
],
23+
srcs_version = "PY3",
24+
visibility = ["//visibility:private"],
25+
deps = [
26+
# tensorflow dep1,
27+
"//tensorflow_model_optimization/python/core/clustering/keras:clustering_registry",
28+
],
29+
)
30+
31+
py_library(
32+
name = "cluster_preserve_quantize_registry",
33+
srcs = [
34+
"cluster_preserve_quantize_registry.py",
35+
],
36+
srcs_version = "PY3",
37+
visibility = ["//visibility:private"],
38+
deps = [
39+
# tensorflow dep1,
40+
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantizers",
41+
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
42+
":cluster_utils",
43+
],
44+
)
45+
46+
py_test(
47+
name = "cluster_preserve_quantize_registry_test",
48+
srcs = [
49+
"cluster_preserve_quantize_registry_test.py",
50+
],
51+
python_version = "PY3",
52+
visibility = ["//visibility:private"],
53+
deps = [
54+
# tensorflow dep1,
55+
":cluster_preserve_quantize_registry",
56+
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
57+
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_config",
58+
]
59+
)
60+
61+
py_library(
62+
name = "default_8bit_cluster_preserve_quantize_scheme",
63+
srcs = [
64+
"default_8bit_cluster_preserve_quantize_scheme.py",
65+
],
66+
srcs_version = "PY3",
67+
visibility = ["//visibility:private"],
68+
deps = [
69+
":cluster_preserve_quantize_registry",
70+
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_scheme",
71+
],
72+
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================

0 commit comments

Comments
 (0)