Skip to content

Commit 3221a86

Browse files
rino20tensorflower-gardener
authored andcommitted
Explicitly pass training argument to the layer in QAT module.
PiperOrigin-RevId: 370377067
1 parent 006e377 commit 3221a86

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import inspect
23+
2224
import tensorflow as tf
2325

2426
deserialize_keras_object = tf.keras.utils.deserialize_keras_object
@@ -88,9 +90,11 @@ def __init__(self, layer, quantize_config=None, **kwargs):
8890
self._batch_input_shape = self.layer._batch_input_shape # pylint: disable=protected-access
8991

9092
def call(self, *args, **kwargs):
91-
# TODO(b/185306646): Explicitly wants to pass training argument for the
92-
# layer. Currently, we remove training argument.
93-
if 'training' in kwargs:
93+
arg = inspect.getfullargspec(self.layer.call).args
94+
95+
# Do not propagate the training bool to the underlying layer if it doesn't
96+
# accepts the training bool.
97+
if 'training' not in arg and 'training' in kwargs:
9498
del kwargs['training']
9599
return self.layer.call(*args, **kwargs)
96100

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate_test.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
from __future__ import print_function
2020

2121
import numpy as np
22-
2322
import tensorflow as tf
2423

24+
from tensorflow.python.keras.engine.base_layer import Layer
2525
from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate
2626
from tensorflow_model_optimization.python.core.quantization.keras import quantize_config as quantize_config_mod
2727

@@ -52,6 +52,21 @@ def get_output_quantizers(self, layer):
5252
def get_config(self):
5353
return {}
5454

55+
def testAnnotateLayerCallPassesTraningBoolean(self):
56+
57+
class MockLayer(Layer):
58+
self.training = None
59+
60+
def call(self, training=None):
61+
self.training = training
62+
63+
layer = MockLayer()
64+
wrapper = quantize_annotate.QuantizeAnnotate(layer=layer)
65+
wrapper.call(training=True)
66+
self.assertTrue(layer.training)
67+
wrapper.call(training=False)
68+
self.assertFalse(layer.training)
69+
5570
def testAnnotatesKerasLayer(self):
5671
layer = keras.layers.Dense(5, activation='relu', input_shape=(10,))
5772
model = keras.Sequential([layer])

0 commit comments

Comments
 (0)