from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_transforms
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer
from tensorflow_model_optimization.python.core.quantization.keras.layers import conv_batchnorm_test_utils
+from tensorflow_model_optimization.python.core.quantization.keras.layers import dense_batchnorm_test_utils

ModelTransformer = model_transformer.ModelTransformer

Conv2DModel = conv_batchnorm_test_utils.Conv2DModel
DepthwiseConv2DModel = conv_batchnorm_test_utils.DepthwiseConv2DModel
+DenseModel = dense_batchnorm_test_utils.DenseModel

keras = tf.keras
@@ -73,21 +75,26 @@ def _get_model(
          post_bn_activation=activation,
          squeeze_type=squeeze_type,
          normalization_type=normalization_type)
+    elif layer_type == 'Dense':
+      return DenseModel.get_nonfolded_batchnorm_model(
+          post_bn_activation=activation, normalization_type=normalization_type)

  def _get_input_shape(self, layer_type):
    if layer_type == 'Conv2D':
      return Conv2DModel.get_batched_input_shape()
    elif layer_type == 'DepthwiseConv2D':
      return DepthwiseConv2DModel.get_batched_input_shape()
+    elif layer_type == 'Dense':
+      return DenseModel.get_batched_input_shape()

-  def _test_conv_squeeze_bn_activation_transform(
+  def _test_conv_squeeze_or_dense_bn_activation_transform(
      self,
      layer_type,
      squeeze_type,
      normalization_type,
      activation_type,
      transform_class,
-      conv_activation_class,
+      conv_or_dense_activation_class,
      normalization_quantize_config_class):
    model = self._get_model(layer_type,
                            squeeze_type,
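The two new `Dense` branches above call into `dense_batchnorm_test_utils.DenseModel`, a helper added elsewhere in this change and not shown in the diff. Below is a minimal sketch of the interface these call sites assume, mirroring `conv_batchnorm_test_utils.Conv2DModel`; the layer widths and input shape are illustrative placeholders, not the real helper:

```python
import tensorflow as tf


class DenseModel:
  """Hypothetical stand-in for dense_batchnorm_test_utils.DenseModel."""

  @classmethod
  def get_batched_input_shape(cls):
    # (batch, features); the concrete sizes here are placeholders.
    return (1, 10)

  @classmethod
  def get_nonfolded_batchnorm_model(cls,
                                    post_bn_activation=None,
                                    normalization_type='BatchNormalization'):
    # Dense followed by a separate (non-folded) normalization layer,
    # optionally capped by an activation layer.
    norm_cls = (
        tf.keras.layers.BatchNormalization
        if normalization_type == 'BatchNormalization'
        else tf.keras.layers.experimental.SyncBatchNormalization)
    inp = tf.keras.Input(batch_size=1, shape=(10,))
    x = tf.keras.layers.Dense(8)(inp)
    x = norm_cls()(x)
    if post_bn_activation is not None:
      x = post_bn_activation(x)  # e.g. tf.keras.layers.ReLU()
    return tf.keras.Model(inp, x)
```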
@@ -107,7 +114,7 @@ def _test_conv_squeeze_bn_activation_transform(
    bn_layer = transformed_model.layers[2]

    self.assertIsInstance(
-        conv_layer.activation, conv_activation_class)
+        conv_layer.activation, conv_or_dense_activation_class)
    self.assertIsInstance(
        updated_metadata.get(bn_layer.name).get('quantize_config'),
        normalization_quantize_config_class)
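The hunk above shows only the assertions; the elided middle of the helper is where the transform actually runs. A hedged sketch of that step, using only names visible in the diff and assuming a `ModelTransformer(model, transforms, candidate_layers, layer_metadata)` constructor whose `transform()` returns the rewritten model plus updated metadata:

```python
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer


def apply_one_transform(model, transform_class, layer_metadata):
  # Run a single graph transform over the model. `layer_metadata` maps
  # layer names to dicts such as {'quantize_config': ...}.
  transformed_model, updated_metadata = model_transformer.ModelTransformer(
      model,
      [transform_class()],        # e.g. Conv2DBatchNormQuantize()
      candidate_layers=None,      # assumed: None means all layers are eligible
      layer_metadata=layer_metadata).transform()
  return transformed_model, updated_metadata
```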
@@ -123,13 +130,13 @@ def _test_conv_squeeze_bn_activation_transform(
      ('DepthwiseConv2D', 'SyncBatchNormalization'),
  )
  def testConv2DBatchNormQuantize(self, layer_type, normalization_type):
-    self._test_conv_squeeze_bn_activation_transform(
+    self._test_conv_squeeze_or_dense_bn_activation_transform(
        layer_type=layer_type,
        squeeze_type=None,
        normalization_type=normalization_type,
        activation_type=None,
        transform_class=default_n_bit_transforms.Conv2DBatchNormQuantize,
-        conv_activation_class=quantize_aware_activation.NoOpActivation,
+        conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
        normalization_quantize_config_class=
        n_bit_configs.DefaultNBitOutputQuantizeConfig)
@@ -140,14 +147,14 @@ def testConv2DBatchNormQuantize(self, layer_type, normalization_type):
      ('DepthwiseConv2D', 'SyncBatchNormalization'),
  )
  def testConv2DBatchNormReLUQuantize(self, layer_type, normalization_type):
-    self._test_conv_squeeze_bn_activation_transform(
+    self._test_conv_squeeze_or_dense_bn_activation_transform(
        layer_type=layer_type,
        squeeze_type=None,
        normalization_type=normalization_type,
        activation_type='relu',
        transform_class=
        default_n_bit_transforms.Conv2DBatchNormReLUQuantize,
-        conv_activation_class=quantize_aware_activation.NoOpActivation,
+        conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
        normalization_quantize_config_class=
        n_bit_configs.NoOpQuantizeConfig)
@@ -159,14 +166,14 @@ def testConv2DBatchNormReLUQuantize(self, layer_type, normalization_type):
  )
  def testConv2DBatchNormActivationQuantize(
      self, layer_type, normalization_type):
-    self._test_conv_squeeze_bn_activation_transform(
+    self._test_conv_squeeze_or_dense_bn_activation_transform(
        layer_type=layer_type,
        squeeze_type=None,
        normalization_type=normalization_type,
        activation_type='act_relu',
        transform_class=
        default_n_bit_transforms.Conv2DBatchNormActivationQuantize,
-        conv_activation_class=quantize_aware_activation.NoOpActivation,
+        conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
        normalization_quantize_config_class=
        n_bit_configs.NoOpQuantizeConfig)
@@ -178,14 +185,14 @@ def testConv2DBatchNormActivationQuantize(
  )
  def testConv2DReshapeBatchNormQuantize(
      self, layer_type, normalization_type):
-    self._test_conv_squeeze_bn_activation_transform(
+    self._test_conv_squeeze_or_dense_bn_activation_transform(
        layer_type=layer_type,
        squeeze_type='sepconv1d_squeeze',
        normalization_type=normalization_type,
        activation_type=False,
        transform_class=
        default_n_bit_transforms.Conv2DReshapeBatchNormQuantize,
-        conv_activation_class=quantize_aware_activation.NoOpActivation,
+        conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
        normalization_quantize_config_class=
        n_bit_configs.DefaultNBitOutputQuantizeConfig)
@@ -197,14 +204,14 @@ def testConv2DReshapeBatchNormQuantize(
  )
  def testConv2DReshapeBatchNormReLUQuantize(
      self, layer_type, normalization_type):
-    self._test_conv_squeeze_bn_activation_transform(
+    self._test_conv_squeeze_or_dense_bn_activation_transform(
        layer_type=layer_type,
        squeeze_type='sepconv1d_squeeze',
        normalization_type=normalization_type,
        activation_type='relu',
        transform_class=
        default_n_bit_transforms.Conv2DReshapeBatchNormReLUQuantize,
-        conv_activation_class=quantize_aware_activation.NoOpActivation,
+        conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
        normalization_quantize_config_class=
        n_bit_configs.NoOpQuantizeConfig)
@@ -216,17 +223,64 @@ def testConv2DReshapeBatchNormReLUQuantize(
  )
  def testConv2DReshapeBatchNormActivationQuantize(
      self, layer_type, normalization_type):
-    self._test_conv_squeeze_bn_activation_transform(
+    self._test_conv_squeeze_or_dense_bn_activation_transform(
        layer_type=layer_type,
        squeeze_type='sepconv1d_squeeze',
        normalization_type=normalization_type,
        activation_type='act_relu',
        transform_class=
        default_n_bit_transforms.Conv2DReshapeBatchNormActivationQuantize,
-        conv_activation_class=quantize_aware_activation.NoOpActivation,
+        conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
        normalization_quantize_config_class=
        n_bit_configs.NoOpQuantizeConfig)

+  @parameterized.parameters(
+      ('Dense', 'BatchNormalization'),
+      ('Dense', 'SyncBatchNormalization'),
+  )
+  def testDenseBatchNormQuantize(self, layer_type, normalization_type):
+    self._test_conv_squeeze_or_dense_bn_activation_transform(
+        layer_type=layer_type,
+        squeeze_type=None,
+        normalization_type=normalization_type,
+        activation_type=None,
+        transform_class=default_n_bit_transforms.DenseBatchNormQuantize,
+        conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
+        normalization_quantize_config_class=n_bit_configs
+        .DefaultNBitOutputQuantizeConfig)
+
+  @parameterized.parameters(
+      ('Dense', 'BatchNormalization'),
+      ('Dense', 'SyncBatchNormalization'),
+  )
+  def testDenseBatchNormReLUQuantize(self, layer_type, normalization_type):
+    self._test_conv_squeeze_or_dense_bn_activation_transform(
+        layer_type=layer_type,
+        squeeze_type=None,
+        normalization_type=normalization_type,
+        activation_type='relu',
+        transform_class=default_n_bit_transforms.DenseBatchNormReLUQuantize,
+        conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
+        normalization_quantize_config_class=n_bit_configs
+        .NoOpQuantizeConfig)
+
+  @parameterized.parameters(
+      ('Dense', 'BatchNormalization'),
+      ('Dense', 'SyncBatchNormalization'),
+  )
+  def testDenseBatchNormActivationQuantize(self, layer_type,
+                                           normalization_type):
+    self._test_conv_squeeze_or_dense_bn_activation_transform(
+        layer_type=layer_type,
+        squeeze_type=None,
+        normalization_type=normalization_type,
+        activation_type='act_relu',
+        transform_class=default_n_bit_transforms
+        .DenseBatchNormActivationQuantize,
+        conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
+        normalization_quantize_config_class=n_bit_configs
+        .NoOpQuantizeConfig)
+
  @parameterized.named_parameters(
      ('padding_valid', {'padding': 'valid'}),
      ('padding_same', {'padding': 'same'}),
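The three new Dense tests mirror the existing Conv2D ones one-to-one. For orientation, here is a sketch of the graph pattern a transform such as `DenseBatchNormQuantize` presumably matches, inferred by analogy with `Conv2DBatchNormQuantize`; the actual `pattern()` lives in `default_n_bit_transforms.py` and is not part of this diff:

```python
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms


def assumed_dense_batchnorm_pattern():
  # Match a (Sync)BatchNormalization whose single input is a Dense layer
  # with a linear activation, so the pair can be quantized as one unit.
  return transforms.LayerPattern(
      'BatchNormalization|SyncBatchNormalization',
      inputs=[transforms.LayerPattern(
          'Dense', config={'activation': 'linear'})])
```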