29
29
from tensorflow_model_optimization .python .core .quantization .keras .default_8bit import default_8bit_transforms
30
30
from tensorflow_model_optimization .python .core .quantization .keras .graph_transformations import model_transformer
31
31
from tensorflow_model_optimization .python .core .quantization .keras .layers import conv_batchnorm_test_utils
32
+ from tensorflow_model_optimization .python .core .quantization .keras .layers import dense_batchnorm_test_utils
32
33
33
34
ModelTransformer = model_transformer .ModelTransformer
34
35
35
36
Conv2DModel = conv_batchnorm_test_utils .Conv2DModel
36
37
DepthwiseConv2DModel = conv_batchnorm_test_utils .DepthwiseConv2DModel
38
+ DenseModel = dense_batchnorm_test_utils .DenseModel
37
39
38
40
keras = tf .keras
39
41
@@ -73,22 +75,27 @@ def _get_model(
73
75
post_bn_activation = activation ,
74
76
squeeze_type = squeeze_type ,
75
77
normalization_type = normalization_type )
78
+ elif layer_type == 'Dense' :
79
+ return DenseModel .get_nonfolded_batchnorm_model (
80
+ post_bn_activation = activation , normalization_type = normalization_type )
76
81
77
82
def _get_input_shape (self , layer_type ):
78
83
if layer_type == 'Conv2D' :
79
84
return Conv2DModel .get_batched_input_shape ()
80
85
elif layer_type == 'DepthwiseConv2D' :
81
86
return DepthwiseConv2DModel .get_batched_input_shape ()
87
+ elif layer_type == 'Dense' :
88
+ return DenseModel .get_batched_input_shape ()
82
89
83
- def _test_conv_squeeze_bn_activation_transform (
90
+ def _test_conv_squeeze_or_dense_bn_activation_transform (
84
91
self ,
85
92
layer_type ,
86
- squeeze_type ,
87
93
normalization_type ,
88
94
activation_type ,
89
95
transform_class ,
90
- conv_activation_class ,
91
- normalization_quantize_config_class ):
96
+ conv_or_dense_activation_class ,
97
+ normalization_quantize_config_class ,
98
+ squeeze_type = None ):
92
99
model = self ._get_model (layer_type ,
93
100
squeeze_type ,
94
101
normalization_type ,
@@ -100,14 +107,14 @@ def _test_conv_squeeze_bn_activation_transform(
100
107
[transform_class ()],
101
108
).transform ()
102
109
103
- conv_layer = transformed_model .layers [1 ]
110
+ conv_or_dense_layer = transformed_model .layers [1 ]
104
111
if squeeze_type == 'sepconv1d_squeeze' :
105
112
bn_layer = transformed_model .layers [3 ]
106
113
else :
107
114
bn_layer = transformed_model .layers [2 ]
108
115
109
- self .assertIsInstance (
110
- conv_layer . activation , conv_activation_class )
116
+ self .assertIsInstance (conv_or_dense_layer . activation ,
117
+ conv_or_dense_activation_class )
111
118
self .assertIsInstance (
112
119
updated_metadata .get (bn_layer .name ).get ('quantize_config' ),
113
120
normalization_quantize_config_class )
@@ -123,15 +130,15 @@ def _test_conv_squeeze_bn_activation_transform(
123
130
('DepthwiseConv2D' , 'SyncBatchNormalization' ),
124
131
)
125
132
def testConv2DBatchNormQuantize (self , layer_type , normalization_type ):
126
- self ._test_conv_squeeze_bn_activation_transform (
133
+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
127
134
layer_type = layer_type ,
128
135
squeeze_type = None ,
129
136
normalization_type = normalization_type ,
130
137
activation_type = None ,
131
138
transform_class = default_8bit_transforms .Conv2DBatchNormQuantize ,
132
- conv_activation_class = quantize_aware_activation .NoOpActivation ,
133
- normalization_quantize_config_class =
134
- default_8bit_quantize_configs .Default8BitOutputQuantizeConfig )
139
+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
140
+ normalization_quantize_config_class = default_8bit_quantize_configs
141
+ .Default8BitOutputQuantizeConfig )
135
142
136
143
@parameterized .parameters (
137
144
('Conv2D' , 'BatchNormalization' ),
@@ -140,16 +147,15 @@ def testConv2DBatchNormQuantize(self, layer_type, normalization_type):
140
147
('DepthwiseConv2D' , 'SyncBatchNormalization' ),
141
148
)
142
149
def testConv2DBatchNormReLUQuantize (self , layer_type , normalization_type ):
143
- self ._test_conv_squeeze_bn_activation_transform (
150
+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
144
151
layer_type = layer_type ,
145
152
squeeze_type = None ,
146
153
normalization_type = normalization_type ,
147
154
activation_type = 'relu' ,
148
- transform_class =
149
- default_8bit_transforms .Conv2DBatchNormReLUQuantize ,
150
- conv_activation_class = quantize_aware_activation .NoOpActivation ,
151
- normalization_quantize_config_class =
152
- default_8bit_quantize_configs .NoOpQuantizeConfig )
155
+ transform_class = default_8bit_transforms .Conv2DBatchNormReLUQuantize ,
156
+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
157
+ normalization_quantize_config_class = default_8bit_quantize_configs
158
+ .NoOpQuantizeConfig )
153
159
154
160
@parameterized .parameters (
155
161
('Conv2D' , 'BatchNormalization' ),
@@ -159,16 +165,16 @@ def testConv2DBatchNormReLUQuantize(self, layer_type, normalization_type):
159
165
)
160
166
def testConv2DBatchNormActivationQuantize (
161
167
self , layer_type , normalization_type ):
162
- self ._test_conv_squeeze_bn_activation_transform (
168
+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
163
169
layer_type = layer_type ,
164
170
squeeze_type = None ,
165
171
normalization_type = normalization_type ,
166
172
activation_type = 'act_relu' ,
167
- transform_class =
168
- default_8bit_transforms .Conv2DBatchNormActivationQuantize ,
169
- conv_activation_class = quantize_aware_activation .NoOpActivation ,
170
- normalization_quantize_config_class =
171
- default_8bit_quantize_configs .NoOpQuantizeConfig )
173
+ transform_class = default_8bit_transforms
174
+ .Conv2DBatchNormActivationQuantize ,
175
+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
176
+ normalization_quantize_config_class = default_8bit_quantize_configs
177
+ .NoOpQuantizeConfig )
172
178
173
179
@parameterized .parameters (
174
180
('Conv2D' , 'BatchNormalization' ),
@@ -178,16 +184,15 @@ def testConv2DBatchNormActivationQuantize(
178
184
)
179
185
def testConv2DReshapeBatchNormQuantize (
180
186
self , layer_type , normalization_type ):
181
- self ._test_conv_squeeze_bn_activation_transform (
187
+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
182
188
layer_type = layer_type ,
183
189
squeeze_type = 'sepconv1d_squeeze' ,
184
190
normalization_type = normalization_type ,
185
191
activation_type = False ,
186
- transform_class =
187
- default_8bit_transforms .Conv2DReshapeBatchNormQuantize ,
188
- conv_activation_class = quantize_aware_activation .NoOpActivation ,
189
- normalization_quantize_config_class =
190
- default_8bit_quantize_configs .Default8BitOutputQuantizeConfig )
192
+ transform_class = default_8bit_transforms .Conv2DReshapeBatchNormQuantize ,
193
+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
194
+ normalization_quantize_config_class = default_8bit_quantize_configs
195
+ .Default8BitOutputQuantizeConfig )
191
196
192
197
@parameterized .parameters (
193
198
('Conv2D' , 'BatchNormalization' ),
@@ -197,16 +202,16 @@ def testConv2DReshapeBatchNormQuantize(
197
202
)
198
203
def testConv2DReshapeBatchNormReLUQuantize (
199
204
self , layer_type , normalization_type ):
200
- self ._test_conv_squeeze_bn_activation_transform (
205
+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
201
206
layer_type = layer_type ,
202
207
squeeze_type = 'sepconv1d_squeeze' ,
203
208
normalization_type = normalization_type ,
204
209
activation_type = 'relu' ,
205
- transform_class =
206
- default_8bit_transforms .Conv2DReshapeBatchNormReLUQuantize ,
207
- conv_activation_class = quantize_aware_activation .NoOpActivation ,
208
- normalization_quantize_config_class =
209
- default_8bit_quantize_configs .NoOpQuantizeConfig )
210
+ transform_class = default_8bit_transforms
211
+ .Conv2DReshapeBatchNormReLUQuantize ,
212
+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
213
+ normalization_quantize_config_class = default_8bit_quantize_configs
214
+ .NoOpQuantizeConfig )
210
215
211
216
@parameterized .parameters (
212
217
('Conv2D' , 'BatchNormalization' ),
@@ -216,16 +221,63 @@ def testConv2DReshapeBatchNormReLUQuantize(
216
221
)
217
222
def testConv2DReshapeBatchNormActivationQuantize (
218
223
self , layer_type , normalization_type ):
219
- self ._test_conv_squeeze_bn_activation_transform (
224
+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
220
225
layer_type = layer_type ,
221
226
squeeze_type = 'sepconv1d_squeeze' ,
222
227
normalization_type = normalization_type ,
223
228
activation_type = 'act_relu' ,
224
- transform_class =
225
- default_8bit_transforms .Conv2DReshapeBatchNormActivationQuantize ,
226
- conv_activation_class = quantize_aware_activation .NoOpActivation ,
227
- normalization_quantize_config_class =
228
- default_8bit_quantize_configs .NoOpQuantizeConfig )
229
+ transform_class = default_8bit_transforms
230
+ .Conv2DReshapeBatchNormActivationQuantize ,
231
+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
232
+ normalization_quantize_config_class = default_8bit_quantize_configs
233
+ .NoOpQuantizeConfig )
234
+
235
+ @parameterized .parameters (
236
+ ('Dense' , 'BatchNormalization' ),
237
+ ('Dense' , 'SyncBatchNormalization' ),
238
+ )
239
+ def testDenseBatchNormQuantize (self , layer_type , normalization_type ):
240
+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
241
+ layer_type = layer_type ,
242
+ squeeze_type = None ,
243
+ normalization_type = normalization_type ,
244
+ activation_type = None ,
245
+ transform_class = default_8bit_transforms .DenseBatchNormQuantize ,
246
+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
247
+ normalization_quantize_config_class = default_8bit_quantize_configs
248
+ .Default8BitOutputQuantizeConfig )
249
+
250
+ @parameterized .parameters (
251
+ ('Dense' , 'BatchNormalization' ),
252
+ ('Dense' , 'SyncBatchNormalization' ),
253
+ )
254
+ def testDenseBatchNormReLUQuantize (self , layer_type , normalization_type ):
255
+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
256
+ layer_type = layer_type ,
257
+ squeeze_type = None ,
258
+ normalization_type = normalization_type ,
259
+ activation_type = 'relu' ,
260
+ transform_class = default_8bit_transforms .DenseBatchNormReLUQuantize ,
261
+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
262
+ normalization_quantize_config_class = default_8bit_quantize_configs
263
+ .NoOpQuantizeConfig )
264
+
265
+ @parameterized .parameters (
266
+ ('Dense' , 'BatchNormalization' ),
267
+ ('Dense' , 'SyncBatchNormalization' ),
268
+ )
269
+ def testDenseBatchNormActivationQuantize (self , layer_type ,
270
+ normalization_type ):
271
+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
272
+ layer_type = layer_type ,
273
+ squeeze_type = None ,
274
+ normalization_type = normalization_type ,
275
+ activation_type = 'act_relu' ,
276
+ transform_class = default_8bit_transforms
277
+ .DenseBatchNormActivationQuantize ,
278
+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
279
+ normalization_quantize_config_class = default_8bit_quantize_configs
280
+ .NoOpQuantizeConfig )
229
281
230
282
@parameterized .named_parameters (
231
283
('padding_valid' , {'padding' : 'valid' }),
0 commit comments