@@ -53,6 +53,11 @@ class PrunePreserveQuantizeRegistry(object):
       layers.Dense:
           _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
 
+      # DepthwiseConv2D is supported with 8bit qat, but not with prune,
+      # thus for DepthwiseConv2D PQAT, weights sparsity preserve is disabled.
+      layers.DepthwiseConv2D:
+          _PrunePreserveInfo(['depthwise_kernel'], ['Default8BitQuantizeConfig']),
+
       # layers that supported with prune, but not yet with qat
       # layers.Conv1D:
       #   _PrunePreserveInfo(['kernel'], []),
@@ -67,10 +72,6 @@ class PrunePreserveQuantizeRegistry(object):
       # layers.LocallyConnected2D:
       #   _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
 
-      # DepthwiseCon2D is supported with 8bit qat, but not with prune
-      # layers.DepthwiseConv2D:
-      #   _PrunePreserveInfo(['depthwise_kernel'], ['Default8BitConvQuantizeConfig']),
-
       # SeparableConv need verify from 8bit qat
       # layers.SeparableConv1D:
       #   _PrunePreserveInfo(['pointwise_kernel'], ['Default8BitConvQuantizeConfig']),
@@ -81,6 +82,10 @@ class PrunePreserveQuantizeRegistry(object):
       # layers.Embedding: _PrunePreserveInfo(['embeddings'], []),
   }
 
+  _DISABLE_PRUNE_PRESERVE = {
+      layers.DepthwiseConv2D,
+  }
+
   def __init__(self):
 
     self._config_quantizer_map = {
@@ -103,6 +108,19 @@ def _no_trainable_weights(cls, layer):
 
     return len(layer.trainable_weights) == 0
 
+  @classmethod
+  def _disable_prune_preserve(cls, layer):
+    """Returns whether disable this layer for prune preserve.
+
+    Args:
+      layer: The layer to check for disable.
+
+    Returns:
+      True/False whether disable this layer for prune preserve.
+    """
+
+    return layer.__class__ in cls._DISABLE_PRUNE_PRESERVE
+
   @classmethod
   def supports(cls, layer):
     """Returns whether the registry supports this layer type.
@@ -174,7 +192,7 @@ def apply_sparsity_preserve_quantize_config(self, layer, quantize_config):
       Returns quantize_config with addon sparsity preserve weight_quantizer.
     """
     if self.supports(layer):
-      if self._no_trainable_weights(layer):
+      if self._no_trainable_weights(layer) or self._disable_prune_preserve(layer):
         return quantize_config
       if (quantize_config.__class__.__name__
           in self._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs):
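
Net effect of the change: DepthwiseConv2D moves from a commented-out entry to an active `_LAYERS_CONFIG_MAP` entry, and the new `_DISABLE_PRUNE_PRESERVE` set makes `apply_sparsity_preserve_quantize_config` hand the layer's quantize config back untouched, so DepthwiseConv2D gets plain 8-bit QAT without the sparsity-preserving weight quantizer. Below is a minimal, self-contained sketch of that gating pattern; the `_RegistrySketch` name, the namedtuple stand-in for `_PrunePreserveInfo`, and the `object()` config placeholder are illustrative assumptions, not the library's API.

```python
import collections

import tensorflow as tf

layers = tf.keras.layers

# Stand-in mirroring the _PrunePreserveInfo container in the diff
# (field names assumed from the two positional lists it is built with).
_PrunePreserveInfo = collections.namedtuple(
    '_PrunePreserveInfo', ['weight_attrs', 'quantize_config_attrs'])


class _RegistrySketch(object):
  """Illustrates only the gating logic added by this change."""

  _LAYERS_CONFIG_MAP = {
      layers.Dense:
          _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
      # Registered so PQAT can quantize it, but excluded from sparsity
      # preservation via _DISABLE_PRUNE_PRESERVE below.
      layers.DepthwiseConv2D:
          _PrunePreserveInfo(['depthwise_kernel'],
                             ['Default8BitQuantizeConfig']),
  }

  _DISABLE_PRUNE_PRESERVE = {
      layers.DepthwiseConv2D,
  }

  @classmethod
  def supports(cls, layer):
    return layer.__class__ in cls._LAYERS_CONFIG_MAP

  @classmethod
  def _disable_prune_preserve(cls, layer):
    return layer.__class__ in cls._DISABLE_PRUNE_PRESERVE

  def apply_sparsity_preserve_quantize_config(self, layer, quantize_config):
    if self.supports(layer):
      if self._disable_prune_preserve(layer):
        # DepthwiseConv2D exits here: it keeps its plain 8-bit QAT config
        # and never receives the sparsity-preserving weight quantizer.
        return quantize_config
      # ...otherwise the real registry would swap in a prune-preserving
      # weight quantizer here (elided in this sketch).
    return quantize_config


registry = _RegistrySketch()
dw = layers.DepthwiseConv2D(kernel_size=3)
config = object()  # hypothetical placeholder for a QuantizeConfig
assert registry.apply_sparsity_preserve_quantize_config(dw, config) is config
```

Keeping a separate disable set, rather than deleting the registry entry, means `supports()` still returns True for DepthwiseConv2D: PQAT can quantize the layer, and only the sparsity-preservation step is skipped.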