@@ -53,6 +53,11 @@ class PrunePreserveQuantizeRegistry(object):
5353 layers .Dense :
5454 _PrunePreserveInfo (['kernel' ], ['Default8BitQuantizeConfig' ]),
5555
56+ # DepthwiseConv2D is supported with 8bit qat, but not with prune,
57+ # thus for DepthwiseConv2D PQAT, weights sparsity preserve is disabled.
58+ layers .DepthwiseConv2D :
59+ _PrunePreserveInfo (['depthwise_kernel' ], ['Default8BitQuantizeConfig' ]),
60+
5661 # layers that supported with prune, but not yet with qat
5762 # layers.Conv1D:
5863 # _PrunePreserveInfo(['kernel'], []),
@@ -67,10 +72,6 @@ class PrunePreserveQuantizeRegistry(object):
6772 # layers.LocallyConnected2D:
6873 # _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
6974
70- # DepthwiseCon2D is supported with 8bit qat, but not with prune
71- # layers.DepthwiseConv2D:
72- # _PrunePreserveInfo(['depthwise_kernel'], ['Default8BitConvQuantizeConfig']),
73-
7475 # SeparableConv need verify from 8bit qat
7576 # layers.SeparableConv1D:
7677 # _PrunePreserveInfo(['pointwise_kernel'], ['Default8BitConvQuantizeConfig']),
@@ -81,6 +82,10 @@ class PrunePreserveQuantizeRegistry(object):
8182 # layers.Embedding: _PrunePreserveInfo(['embeddings'], []),
8283 }
8384
85+ _DISABLE_PRUNE_PRESERVE = {
86+ layers .DepthwiseConv2D ,
87+ }
88+
8489 def __init__ (self ):
8590
8691 self ._config_quantizer_map = {
@@ -103,6 +108,19 @@ def _no_trainable_weights(cls, layer):
103108
104109 return len (layer .trainable_weights ) == 0
105110
111+ @classmethod
112+ def _disable_prune_preserve (cls , layer ):
113+ """Returns whether disable this layer for prune preserve.
114+
115+ Args:
116+ layer: The layer to check for disable.
117+
118+ Returns:
119+ True/False whether disable this layer for prune preserve.
120+ """
121+
122+ return layer .__class__ in cls ._DISABLE_PRUNE_PRESERVE
123+
106124 @classmethod
107125 def supports (cls , layer ):
108126 """Returns whether the registry supports this layer type.
@@ -174,7 +192,7 @@ def apply_sparsity_preserve_quantize_config(self, layer, quantize_config):
174192 Returns quantize_config with addon sparsity preserve weight_quantizer.
175193 """
176194 if self .supports (layer ):
177- if self ._no_trainable_weights (layer ):
195+ if self ._no_trainable_weights (layer ) or self . _disable_prune_preserve ( layer ) :
178196 return quantize_config
179197 if (quantize_config .__class__ .__name__
180198 in self ._LAYERS_CONFIG_MAP [layer .__class__ ].quantize_config_attrs ):
0 commit comments