2929from tensorflow_model_optimization .python .core .quantization .keras .default_8bit import default_8bit_transforms
3030from tensorflow_model_optimization .python .core .quantization .keras .graph_transformations import model_transformer
3131from tensorflow_model_optimization .python .core .quantization .keras .layers import conv_batchnorm_test_utils
32+ from tensorflow_model_optimization .python .core .quantization .keras .layers import dense_batchnorm_test_utils
3233
# Short aliases for the classes exercised in these tests, so test bodies
# stay readable.
ModelTransformer = model_transformer.ModelTransformer

Conv2DModel = conv_batchnorm_test_utils.Conv2DModel
DepthwiseConv2DModel = conv_batchnorm_test_utils.DepthwiseConv2DModel
DenseModel = dense_batchnorm_test_utils.DenseModel

keras = tf.keras
3941
@@ -73,22 +75,27 @@ def _get_model(
7375 post_bn_activation = activation ,
7476 squeeze_type = squeeze_type ,
7577 normalization_type = normalization_type )
78+ elif layer_type == 'Dense' :
79+ return DenseModel .get_nonfolded_batchnorm_model (
80+ post_bn_activation = activation , normalization_type = normalization_type )
7681
7782 def _get_input_shape (self , layer_type ):
7883 if layer_type == 'Conv2D' :
7984 return Conv2DModel .get_batched_input_shape ()
8085 elif layer_type == 'DepthwiseConv2D' :
8186 return DepthwiseConv2DModel .get_batched_input_shape ()
87+ elif layer_type == 'Dense' :
88+ return DenseModel .get_batched_input_shape ()
8289
83- def _test_conv_squeeze_bn_activation_transform (
90+ def _test_conv_squeeze_or_dense_bn_activation_transform (
8491 self ,
8592 layer_type ,
86- squeeze_type ,
8793 normalization_type ,
8894 activation_type ,
8995 transform_class ,
90- conv_activation_class ,
91- normalization_quantize_config_class ):
96+ conv_or_dense_activation_class ,
97+ normalization_quantize_config_class ,
98+ squeeze_type = None ):
9299 model = self ._get_model (layer_type ,
93100 squeeze_type ,
94101 normalization_type ,
@@ -100,14 +107,14 @@ def _test_conv_squeeze_bn_activation_transform(
100107 [transform_class ()],
101108 ).transform ()
102109
103- conv_layer = transformed_model .layers [1 ]
110+ conv_or_dense_layer = transformed_model .layers [1 ]
104111 if squeeze_type == 'sepconv1d_squeeze' :
105112 bn_layer = transformed_model .layers [3 ]
106113 else :
107114 bn_layer = transformed_model .layers [2 ]
108115
109- self .assertIsInstance (
110- conv_layer . activation , conv_activation_class )
116+ self .assertIsInstance (conv_or_dense_layer . activation ,
117+ conv_or_dense_activation_class )
111118 self .assertIsInstance (
112119 updated_metadata .get (bn_layer .name ).get ('quantize_config' ),
113120 normalization_quantize_config_class )
@@ -123,15 +130,15 @@ def _test_conv_squeeze_bn_activation_transform(
123130 ('DepthwiseConv2D' , 'SyncBatchNormalization' ),
124131 )
125132 def testConv2DBatchNormQuantize (self , layer_type , normalization_type ):
126- self ._test_conv_squeeze_bn_activation_transform (
133+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
127134 layer_type = layer_type ,
128135 squeeze_type = None ,
129136 normalization_type = normalization_type ,
130137 activation_type = None ,
131138 transform_class = default_8bit_transforms .Conv2DBatchNormQuantize ,
132- conv_activation_class = quantize_aware_activation .NoOpActivation ,
133- normalization_quantize_config_class =
134- default_8bit_quantize_configs .Default8BitOutputQuantizeConfig )
139+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
140+ normalization_quantize_config_class = default_8bit_quantize_configs
141+ .Default8BitOutputQuantizeConfig )
135142
136143 @parameterized .parameters (
137144 ('Conv2D' , 'BatchNormalization' ),
@@ -140,16 +147,15 @@ def testConv2DBatchNormQuantize(self, layer_type, normalization_type):
140147 ('DepthwiseConv2D' , 'SyncBatchNormalization' ),
141148 )
142149 def testConv2DBatchNormReLUQuantize (self , layer_type , normalization_type ):
143- self ._test_conv_squeeze_bn_activation_transform (
150+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
144151 layer_type = layer_type ,
145152 squeeze_type = None ,
146153 normalization_type = normalization_type ,
147154 activation_type = 'relu' ,
148- transform_class =
149- default_8bit_transforms .Conv2DBatchNormReLUQuantize ,
150- conv_activation_class = quantize_aware_activation .NoOpActivation ,
151- normalization_quantize_config_class =
152- default_8bit_quantize_configs .NoOpQuantizeConfig )
155+ transform_class = default_8bit_transforms .Conv2DBatchNormReLUQuantize ,
156+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
157+ normalization_quantize_config_class = default_8bit_quantize_configs
158+ .NoOpQuantizeConfig )
153159
154160 @parameterized .parameters (
155161 ('Conv2D' , 'BatchNormalization' ),
@@ -159,16 +165,16 @@ def testConv2DBatchNormReLUQuantize(self, layer_type, normalization_type):
159165 )
160166 def testConv2DBatchNormActivationQuantize (
161167 self , layer_type , normalization_type ):
162- self ._test_conv_squeeze_bn_activation_transform (
168+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
163169 layer_type = layer_type ,
164170 squeeze_type = None ,
165171 normalization_type = normalization_type ,
166172 activation_type = 'act_relu' ,
167- transform_class =
168- default_8bit_transforms .Conv2DBatchNormActivationQuantize ,
169- conv_activation_class = quantize_aware_activation .NoOpActivation ,
170- normalization_quantize_config_class =
171- default_8bit_quantize_configs .NoOpQuantizeConfig )
173+ transform_class = default_8bit_transforms
174+ .Conv2DBatchNormActivationQuantize ,
175+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
176+ normalization_quantize_config_class = default_8bit_quantize_configs
177+ .NoOpQuantizeConfig )
172178
173179 @parameterized .parameters (
174180 ('Conv2D' , 'BatchNormalization' ),
@@ -178,16 +184,15 @@ def testConv2DBatchNormActivationQuantize(
178184 )
179185 def testConv2DReshapeBatchNormQuantize (
180186 self , layer_type , normalization_type ):
181- self ._test_conv_squeeze_bn_activation_transform (
187+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
182188 layer_type = layer_type ,
183189 squeeze_type = 'sepconv1d_squeeze' ,
184190 normalization_type = normalization_type ,
185191 activation_type = False ,
186- transform_class =
187- default_8bit_transforms .Conv2DReshapeBatchNormQuantize ,
188- conv_activation_class = quantize_aware_activation .NoOpActivation ,
189- normalization_quantize_config_class =
190- default_8bit_quantize_configs .Default8BitOutputQuantizeConfig )
192+ transform_class = default_8bit_transforms .Conv2DReshapeBatchNormQuantize ,
193+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
194+ normalization_quantize_config_class = default_8bit_quantize_configs
195+ .Default8BitOutputQuantizeConfig )
191196
192197 @parameterized .parameters (
193198 ('Conv2D' , 'BatchNormalization' ),
@@ -197,16 +202,16 @@ def testConv2DReshapeBatchNormQuantize(
197202 )
198203 def testConv2DReshapeBatchNormReLUQuantize (
199204 self , layer_type , normalization_type ):
200- self ._test_conv_squeeze_bn_activation_transform (
205+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
201206 layer_type = layer_type ,
202207 squeeze_type = 'sepconv1d_squeeze' ,
203208 normalization_type = normalization_type ,
204209 activation_type = 'relu' ,
205- transform_class =
206- default_8bit_transforms .Conv2DReshapeBatchNormReLUQuantize ,
207- conv_activation_class = quantize_aware_activation .NoOpActivation ,
208- normalization_quantize_config_class =
209- default_8bit_quantize_configs .NoOpQuantizeConfig )
210+ transform_class = default_8bit_transforms
211+ .Conv2DReshapeBatchNormReLUQuantize ,
212+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
213+ normalization_quantize_config_class = default_8bit_quantize_configs
214+ .NoOpQuantizeConfig )
210215
211216 @parameterized .parameters (
212217 ('Conv2D' , 'BatchNormalization' ),
@@ -216,16 +221,63 @@ def testConv2DReshapeBatchNormReLUQuantize(
216221 )
217222 def testConv2DReshapeBatchNormActivationQuantize (
218223 self , layer_type , normalization_type ):
219- self ._test_conv_squeeze_bn_activation_transform (
224+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
220225 layer_type = layer_type ,
221226 squeeze_type = 'sepconv1d_squeeze' ,
222227 normalization_type = normalization_type ,
223228 activation_type = 'act_relu' ,
224- transform_class =
225- default_8bit_transforms .Conv2DReshapeBatchNormActivationQuantize ,
226- conv_activation_class = quantize_aware_activation .NoOpActivation ,
227- normalization_quantize_config_class =
228- default_8bit_quantize_configs .NoOpQuantizeConfig )
229+ transform_class = default_8bit_transforms
230+ .Conv2DReshapeBatchNormActivationQuantize ,
231+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
232+ normalization_quantize_config_class = default_8bit_quantize_configs
233+ .NoOpQuantizeConfig )
234+
235+ @parameterized .parameters (
236+ ('Dense' , 'BatchNormalization' ),
237+ ('Dense' , 'SyncBatchNormalization' ),
238+ )
239+ def testDenseBatchNormQuantize (self , layer_type , normalization_type ):
240+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
241+ layer_type = layer_type ,
242+ squeeze_type = None ,
243+ normalization_type = normalization_type ,
244+ activation_type = None ,
245+ transform_class = default_8bit_transforms .DenseBatchNormQuantize ,
246+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
247+ normalization_quantize_config_class = default_8bit_quantize_configs
248+ .Default8BitOutputQuantizeConfig )
249+
250+ @parameterized .parameters (
251+ ('Dense' , 'BatchNormalization' ),
252+ ('Dense' , 'SyncBatchNormalization' ),
253+ )
254+ def testDenseBatchNormReLUQuantize (self , layer_type , normalization_type ):
255+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
256+ layer_type = layer_type ,
257+ squeeze_type = None ,
258+ normalization_type = normalization_type ,
259+ activation_type = 'relu' ,
260+ transform_class = default_8bit_transforms .DenseBatchNormReLUQuantize ,
261+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
262+ normalization_quantize_config_class = default_8bit_quantize_configs
263+ .NoOpQuantizeConfig )
264+
265+ @parameterized .parameters (
266+ ('Dense' , 'BatchNormalization' ),
267+ ('Dense' , 'SyncBatchNormalization' ),
268+ )
269+ def testDenseBatchNormActivationQuantize (self , layer_type ,
270+ normalization_type ):
271+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
272+ layer_type = layer_type ,
273+ squeeze_type = None ,
274+ normalization_type = normalization_type ,
275+ activation_type = 'act_relu' ,
276+ transform_class = default_8bit_transforms
277+ .DenseBatchNormActivationQuantize ,
278+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
279+ normalization_quantize_config_class = default_8bit_quantize_configs
280+ .NoOpQuantizeConfig )
229281
230282 @parameterized .named_parameters (
231283 ('padding_valid' , {'padding' : 'valid' }),
0 commit comments