Skip to content

Commit 7c93a20

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
Include AllValuesQuantizer in default_8bit scheme
PiperOrigin-RevId: 320092094
1 parent baf1bb8 commit 7c93a20

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,8 @@ def pattern(self):
269269
return LayerPattern('InputLayer')
270270

271271
def replacement(self, match_layer):
272-
# TODO(pulkitb): Replace quantizer with InputLayer specific quantizer.
273272
quant_layer = quantize_layer.QuantizeLayer(
274-
quantizers.MovingAverageQuantizer(
273+
quantizers.AllValuesQuantizer(
275274
num_bits=8, per_axis=False, symmetric=False, narrow_range=False))
276275
layer_config = keras.layers.serialize(quant_layer)
277276
layer_config['name'] = quant_layer.name
@@ -285,7 +284,8 @@ def replacement(self, match_layer):
285284
def custom_objects(self):
286285
return {
287286
'QuantizeLayer': quantize_layer.QuantizeLayer,
288-
'MovingAverageQuantizer': quantizers.MovingAverageQuantizer
287+
'MovingAverageQuantizer': quantizers.MovingAverageQuantizer,
288+
'AllValuesQuantizer': quantizers.AllValuesQuantizer
289289
}
290290

291291

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def testAddsQuantizeLayerAfterInputLayer(self):
264264
layer_after_input,
265265
quantize_layer.QuantizeLayer)
266266
self.assertIsInstance(
267-
layer_after_input.quantizer, quantizers.MovingAverageQuantizer)
267+
layer_after_input.quantizer, quantizers.AllValuesQuantizer)
268268

269269
def testConcatTransform(self):
270270
r"""Tests the Concat Transform.

0 commit comments

Comments
 (0)