Skip to content

Commit 92ba30b

Browse files
Xharktensorflower-gardener
authored andcommitted
Removed potentially accumulated errors from testSeparableConvQuantize_
PiperOrigin-RevId: 379835554
1 parent 812ea04 commit 92ba30b

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,8 @@ def testSeparableConv1DQuantize_(self, kwargs):
290290
@parameterized.named_parameters(
291291
('padding_valid', {'padding': 'valid'}),
292292
('padding_same', {'padding': 'same'}),
293-
# TODO(b/186666265): tighten the tolerance to 1e-5.
294293
('padding_same_dilation_2',
295-
{'padding': 'same', 'dilation_rate': 2}, 0.19),
294+
{'padding': 'same', 'dilation_rate': 2}),
296295
('strides', {'strides': 2}),
297296
('dilation_rate', {'dilation_rate': 2}),
298297
('depth_multiplier', {'depth_multiplier': 2}),
@@ -307,7 +306,7 @@ def testSeparableConv1DQuantize_(self, kwargs):
307306
'pointwise_constraint': tf.keras.constraints.min_max_norm(0., 2.),
308307
'bias_constraint': tf.keras.constraints.unit_norm()})
309308
)
310-
def testSeparableConvQuantize_(self, kwargs, tolerance=1e-5):
309+
def testSeparableConvQuantize_(self, kwargs):
311310
kwargs['filters'] = 2
312311
kwargs['kernel_size'] = 3
313312
num_samples = 2
@@ -338,17 +337,20 @@ def testSeparableConvQuantize_(self, kwargs, tolerance=1e-5):
338337

339338
# Ensure model is equivalent, and training results are the same.
340339
sepconv_model.compile(loss='categorical_crossentropy', optimizer='sgd')
341-
sepconv_model.fit(x, y, epochs=100)
342340
transformed_model.compile(loss='categorical_crossentropy', optimizer='sgd')
343-
transformed_model.fit(x, y, epochs=100)
344341

345-
# Over a long training cycle with constraints and regularizers, the model
346-
# can build very minute differences.
347-
self.assertAllClose(
348-
sepconv_model.predict(x),
349-
transformed_model.predict(x),
350-
atol=tolerance,
351-
rtol=tolerance)
342+
epochs = 100
343+
for _ in range(epochs):
344+
sepconv_model.fit(x, y, epochs=1, verbose=2)
345+
transformed_model.fit(x, y, epochs=1, verbose=2)
346+
self.assertAllClose(
347+
sepconv_model.get_weights(),
348+
transformed_model.get_weights())
349+
# To prevent accumulated numerical errors.
350+
transformed_model.set_weights(sepconv_model.get_weights())
351+
self.assertAllClose(
352+
sepconv_model.predict(x),
353+
transformed_model.predict(x))
352354

353355
# TODO(pulkitb): Add individual tests for the following transforms.
354356
# Conv2DReshapeBatchNormQuantize, Conv2DReshapeBatchNormReLUQuantize

0 commit comments

Comments
 (0)