Skip to content

Commit e64880a

Browse files
Xharktensorflower-gardener
authored andcommitted
Supports additional argument for QuantizeAnnotate.call.
PiperOrigin-RevId: 368379016
1 parent 92486bf commit e64880a

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,12 @@ def __init__(self, layer, quantize_config=None, **kwargs):
8787
hasattr(layer, '_batch_input_shape')):
8888
self._batch_input_shape = self.layer._batch_input_shape # pylint: disable=protected-access
8989

90-
def call(self, inputs, training=None):
91-
return self.layer.call(inputs)
90+
def call(self, *args, **kwargs):
91+
# TODO(b/185306646): Explicitly wants to pass training argument for the
92+
# layer. Currently, we remove training argument.
93+
if 'training' in kwargs:
94+
del kwargs['training']
95+
return self.layer.call(*args, **kwargs)
9296

9397
def get_config(self):
9498
base_config = super(QuantizeAnnotate, self).get_config()

0 commit comments

Comments
 (0)