Skip to content

Commit 0b2c262

Browse files
daverimtensorflower-gardener
authored andcommitted
Add better error messages when attempting to quantize Lambda layers.
PiperOrigin-RevId: 415179482
1 parent 1bec520 commit 0b2c262

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
"""Quantization API functions for tf.keras models."""
16+
import warnings
1617

1718
import tensorflow as tf
1819

@@ -177,6 +178,9 @@ def quantize_annotate_model(to_annotate):
177178
New tf.keras model with each layer in the model wrapped with
178179
`QuantizeAnnotate`. The new model preserves weights from the original
179180
model.
181+
182+
Raises:
183+
ValueError: if the model cannot be annotated.
180184
"""
181185
if to_annotate is None:
182186
raise ValueError('`to_annotate` cannot be None')
@@ -200,6 +204,15 @@ def _add_quant_wrapper(layer):
200204
if isinstance(layer, quantize_annotate_mod.QuantizeAnnotate):
201205
return layer
202206

207+
if isinstance(layer, tf.keras.layers.Lambda):
208+
warnings.warn(
209+
'Lambda layers are not supported by automatic model annotation '
210+
'because the internal functionality cannot always be determined by '
211+
'serialization alone. We recommend that you make a custom layer '
212+
'and add a custom QuantizeConfig for it instead. This layer will not '
213+
'be quantized which may lead to unexpected results.')
214+
return layer
215+
203216
if isinstance(layer, tf.keras.Model):
204217
raise ValueError(
205218
'Quantizing a tf.keras Model inside another tf.keras Model is not supported.'

tensorflow_model_optimization/python/core/quantization/keras/quantize_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,21 @@ def testQuantizeAnnotateModel_FailsWithNestedModels(self):
187187
keras.Sequential(
188188
[keras.Sequential([keras.layers.Dense(10, input_shape=(2,))])]))
189189

190+
def testQuantizeAnnotateModel_SkipsLambda(self):
191+
model = keras.Sequential([
192+
keras.layers.Dense(10, input_shape=(5,)),
193+
keras.layers.Dropout(0.4),
194+
keras.layers.Lambda(lambda x: x + 1.0)
195+
])
196+
with self.assertWarns(Warning):
197+
annotated_model = quantize_annotate_model(model)
198+
199+
self._assertWrappedLayer(annotated_model.layers[0])
200+
self._assertWrappedLayer(annotated_model.layers[1])
201+
self.assertIsInstance(annotated_model.layers[2], keras.layers.Lambda)
202+
inputs = np.random.rand(1, 5)
203+
self.assertAllEqual(model.predict(inputs), annotated_model.predict(inputs))
204+
190205

191206
class QuantizeApplyTest(tf.test.TestCase):
192207

0 commit comments

Comments
 (0)