Commit bfa7b3c

Register DepthwiseConv2D in PQAT.
Change-Id: I753eea85c2ecd1fcda54be42def58e144229d388
Parent: c1be7be

1 file changed (+23 −5 lines)

tensorflow_model_optimization/python/core/quantization/keras/prune_preserve/prune_preserve_quantize_registry.py

Lines changed: 23 additions & 5 deletions
@@ -53,6 +53,11 @@ class PrunePreserveQuantizeRegistry(object):
       layers.Dense:
           _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
 
+      # DepthwiseConv2D is supported with 8bit qat, but not with prune,
+      # thus for DepthwiseConv2D PQAT, weights sparsity preserve is disabled.
+      layers.DepthwiseConv2D:
+          _PrunePreserveInfo(['depthwise_kernel'], ['Default8BitQuantizeConfig']),
+
       # layers that supported with prune, but not yet with qat
       # layers.Conv1D:
       #   _PrunePreserveInfo(['kernel'], []),
@@ -67,10 +72,6 @@ class PrunePreserveQuantizeRegistry(object):
       # layers.LocallyConnected2D:
       #   _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
 
-      # DepthwiseCon2D is supported with 8bit qat, but not with prune
-      # layers.DepthwiseConv2D:
-      #   _PrunePreserveInfo(['depthwise_kernel'], ['Default8BitConvQuantizeConfig']),
-
       # SeparableConv need verify from 8bit qat
       # layers.SeparableConv1D:
       #   _PrunePreserveInfo(['pointwise_kernel'], ['Default8BitConvQuantizeConfig']),
@@ -81,6 +82,10 @@ class PrunePreserveQuantizeRegistry(object):
       # layers.Embedding: _PrunePreserveInfo(['embeddings'], []),
   }
 
+  _DISABLE_PRUNE_PRESERVE = {
+      layers.DepthwiseConv2D,
+  }
+
   def __init__(self):
 
     self._config_quantizer_map = {
@@ -103,6 +108,19 @@ def _no_trainable_weights(cls, layer):
 
     return len(layer.trainable_weights) == 0
 
+  @classmethod
+  def _disable_prune_preserve(cls, layer):
+    """Returns whether disable this layer for prune preserve.
+
+    Args:
+      layer: The layer to check for disable.
+
+    Returns:
+      True/False whether disable this layer for prune preserve.
+    """
+
+    return layer.__class__ in cls._DISABLE_PRUNE_PRESERVE
+
   @classmethod
   def supports(cls, layer):
     """Returns whether the registry supports this layer type.
@@ -174,7 +192,7 @@ def apply_sparsity_preserve_quantize_config(self, layer, quantize_config):
       Returns quantize_config with addon sparsity preserve weight_quantizer.
     """
     if self.supports(layer):
-      if self._no_trainable_weights(layer):
+      if self._no_trainable_weights(layer) or self._disable_prune_preserve(layer):
         return quantize_config
       if (quantize_config.__class__.__name__
           in self._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs):
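
Net effect of the change: DepthwiseConv2D now passes the registry's supports() check, but because it is also listed in _DISABLE_PRUNE_PRESERVE, apply_sparsity_preserve_quantize_config() hands the quantize_config back unmodified, so the layer is quantized without a sparsity-preserving weight quantizer. A minimal sketch against the internal registry, using only names that appear in this diff (the import path follows the file path above; illustrative, not a public API):

# Sketch: how the registry treats DepthwiseConv2D after this commit.
from tensorflow.keras import layers
from tensorflow_model_optimization.python.core.quantization.keras.prune_preserve import (
    prune_preserve_quantize_registry,
)

registry = prune_preserve_quantize_registry.PrunePreserveQuantizeRegistry()
layer = layers.DepthwiseConv2D(kernel_size=3)

# Now registered in _LAYERS_CONFIG_MAP, so the registry accepts the layer.
assert registry.supports(layer)

# Also listed in _DISABLE_PRUNE_PRESERVE, so sparsity preservation is
# skipped: apply_sparsity_preserve_quantize_config() returns the given
# quantize_config unchanged instead of attaching a sparsity-preserving
# weight quantizer. (_disable_prune_preserve is private; called here
# purely for illustration.)
assert registry._disable_prune_preserve(layer)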

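In user-facing terms, a pruned model containing DepthwiseConv2D can now go through the PQAT flow rather than being rejected; the caveat is that the depthwise kernel's sparsity is not preserved during quantize training. A rough sketch of that flow, assuming the prune-preserving scheme is exposed as tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme as in the TFMOT collaborative-optimization examples (the toy model and the omitted fine-tuning steps are placeholders):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Placeholder model containing a DepthwiseConv2D layer.
model = tf.keras.Sequential([
    tf.keras.layers.DepthwiseConv2D(3, input_shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])

# 1. Prune, fine-tune (omitted here), then strip the pruning wrappers.
pruned = tfmot.sparsity.keras.prune_low_magnitude(model)
stripped = tfmot.sparsity.keras.strip_pruning(pruned)

# 2. PQAT: quantize with the prune-preserving scheme. With this commit the
#    DepthwiseConv2D layer is accepted; its weights are quantized, but their
#    sparsity is not preserved while training.
annotated = tfmot.quantization.keras.quantize_annotate_model(stripped)
pqat_model = tfmot.quantization.keras.quantize_apply(
    annotated,
    tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme())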