@@ -290,9 +290,8 @@ def testSeparableConv1DQuantize_(self, kwargs):
290
290
@parameterized .named_parameters (
291
291
('padding_valid' , {'padding' : 'valid' }),
292
292
('padding_same' , {'padding' : 'same' }),
293
- # TODO(b/186666265): tighten the tolerance to 1e-5.
294
293
('padding_same_dilation_2' ,
295
- {'padding' : 'same' , 'dilation_rate' : 2 }, 0.19 ),
294
+ {'padding' : 'same' , 'dilation_rate' : 2 }),
296
295
('strides' , {'strides' : 2 }),
297
296
('dilation_rate' , {'dilation_rate' : 2 }),
298
297
('depth_multiplier' , {'depth_multiplier' : 2 }),
@@ -307,7 +306,7 @@ def testSeparableConv1DQuantize_(self, kwargs):
307
306
'pointwise_constraint' : tf .keras .constraints .min_max_norm (0. , 2. ),
308
307
'bias_constraint' : tf .keras .constraints .unit_norm ()})
309
308
)
310
- def testSeparableConvQuantize_ (self , kwargs , tolerance = 1e-5 ):
309
+ def testSeparableConvQuantize_ (self , kwargs ):
311
310
kwargs ['filters' ] = 2
312
311
kwargs ['kernel_size' ] = 3
313
312
num_samples = 2
@@ -338,17 +337,20 @@ def testSeparableConvQuantize_(self, kwargs, tolerance=1e-5):
338
337
339
338
# Ensure model is equivalent, and training results are the same.
340
339
sepconv_model .compile (loss = 'categorical_crossentropy' , optimizer = 'sgd' )
341
- sepconv_model .fit (x , y , epochs = 100 )
342
340
transformed_model .compile (loss = 'categorical_crossentropy' , optimizer = 'sgd' )
343
- transformed_model .fit (x , y , epochs = 100 )
344
341
345
- # Over a long training cycle with constraints and regularizers, the model
346
- # can build very minute differences.
347
- self .assertAllClose (
348
- sepconv_model .predict (x ),
349
- transformed_model .predict (x ),
350
- atol = tolerance ,
351
- rtol = tolerance )
342
+ epochs = 100
343
+ for _ in range (epochs ):
344
+ sepconv_model .fit (x , y , epochs = 1 , verbose = 2 )
345
+ transformed_model .fit (x , y , epochs = 1 , verbose = 2 )
346
+ self .assertAllClose (
347
+ sepconv_model .get_weights (),
348
+ transformed_model .get_weights ())
349
+ # To prevent accumulated numerical errors.
350
+ transformed_model .set_weights (sepconv_model .get_weights ())
351
+ self .assertAllClose (
352
+ sepconv_model .predict (x ),
353
+ transformed_model .predict (x ))
352
354
353
355
# TODO(pulkitb): Add individual tests for the following transforms.
354
356
# Conv2DReshapeBatchNormQuantize, Conv2DReshapeBatchNormReLUQuantize
0 commit comments