
Commit b47e0d9

nutsiepully authored and tensorflower-gardener committed

Propagate training arg to inner layer in QuantizeWrapper

PiperOrigin-RevId: 321031035
1 parent 16f4c6b commit b47e0d9

File tree

1 file changed: +6 -1 lines changed


tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py

Lines changed: 6 additions & 1 deletion
@@ -26,6 +26,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import inspect
 import tensorflow as tf
 
 # TODO(b/139939526): move to public API.
@@ -159,7 +160,11 @@ def call(self, inputs, training=None):
       self.quantize_config.set_quantize_activations(self.layer,
                                                     self._quantize_activations)
 
-    outputs = self.layer.call(inputs)
+    args = inspect.getfullargspec(self.layer.call).args
+    if 'training' in args:
+      outputs = self.layer.call(inputs, training=training)
+    else:
+      outputs = self.layer.call(inputs)
 
     if not self._output_quantizers:
       return outputs
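
For context, here is a minimal, self-contained sketch of the pattern this commit applies: probe the wrapped layer's call signature with inspect and forward the training argument only when the inner layer accepts it. The ForwardingWrapper name is hypothetical and the sketch omits everything else QuantizeWrapper does (it also inserts quantization ops around this call); only the signature check mirrors the diff above.

import inspect

import tensorflow as tf


class ForwardingWrapper(tf.keras.layers.Wrapper):
  """Hypothetical wrapper: passes `training` through only when supported."""

  def call(self, inputs, training=None):
    # Inspect the inner layer's `call` signature, as in the diff above.
    args = inspect.getfullargspec(self.layer.call).args
    if 'training' in args:
      # Layers such as Dropout or BatchNormalization behave differently
      # during training vs. inference, so the flag must be forwarded.
      return self.layer.call(inputs, training=training)
    # Layers like Dense do not accept `training`; passing it would raise
    # a TypeError, so fall back to a plain call.
    return self.layer.call(inputs)


# Usage: without propagation, Dropout inside the wrapper would silently
# run in inference mode; with it, training=True drops (and rescales)
# activations as expected.
layer = ForwardingWrapper(tf.keras.layers.Dropout(0.5))
x = tf.ones((2, 4))
print(layer(x, training=True))   # some entries zeroed, rest rescaled
print(layer(x, training=False))  # identity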
