Commit b6a97f2

alanchiao authored and tensorflower-gardener committed
Improve or add error messages for nested models and subclassed models.
Context: #292. PiperOrigin-RevId: 303861624
1 parent a013182 commit b6a97f2
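
The gist of the change, sketched against the public tfmot API (not code from this commit): quantizing a subclassed Keras model now fails fast with a clear ValueError instead of a confusing downstream error. The snippet below assumes tensorflow and tensorflow_model_optimization are installed and that tfmot.quantization.keras.quantize_model forwards to the quantize_model patched below.

    # Hedged, illustrative sketch of the new behavior for subclassed models.
    import tensorflow as tf
    import tensorflow_model_optimization as tfmot


    class SubclassedModel(tf.keras.Model):
      """Neither Sequential nor Functional, so the new check rejects it."""

      def call(self, inputs):
        return inputs


    try:
      tfmot.quantization.keras.quantize_model(SubclassedModel())
    except ValueError as e:
      # Expected message: '`to_quantize` can only either be a tf.keras
      # Sequential or Functional model.'
      print(e)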

File tree (6 files changed, +76 -0 lines changed)

  tensorflow_model_optimization/python/core/quantization/keras/quantize.py
  tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py
  tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate_test.py
  tensorflow_model_optimization/python/core/quantization/keras/quantize_test.py
  tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py
  tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 18 additions & 0 deletions
@@ -104,6 +104,12 @@ def quantize_model(to_quantize):
         'You passed an instance of type: {input}.'.format(
             input=to_quantize.__class__.__name__))
 
+  if not isinstance(to_quantize, keras.Sequential) \
+      and not to_quantize._is_graph_network:  # pylint: disable=protected-access
+    raise ValueError(
+        '`to_quantize` can only either be a tf.keras Sequential or '
+        'Functional model.')
+
   annotated_model = quantize_annotate_model(to_quantize)
   return quantize_apply(annotated_model)

@@ -149,11 +155,23 @@ def quantize_annotate_model(to_annotate):
         'You passed an instance of type: {input}.'.format(
             input=to_annotate.__class__.__name__))
 
+  if not isinstance(to_annotate, keras.Sequential) \
+      and not to_annotate._is_graph_network:  # pylint: disable=protected-access
+    raise ValueError(
+        '`to_annotate` can only either be a tf.keras Sequential or '
+        'Functional model.')
+
   def _add_quant_wrapper(layer):
+    """Add annotation wrapper."""
     # Already annotated layer. No need to wrap.
     if isinstance(layer, quantize_annotate_mod.QuantizeAnnotate):
       return layer
 
+    if isinstance(layer, tf.keras.Model):
+      raise ValueError(
+          'Quantizing a tf.keras Model inside another tf.keras Model is not supported.'
+      )
+
     return quantize_annotate_mod.QuantizeAnnotate(layer)
 
   return keras.models.clone_model(
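
A hedged companion sketch for the nested-model check above (mirroring testQuantizeAnnotateModel_FailsWithNestedModels further down): the outer Sequential passes the top-level check, but the inner Sequential is itself a tf.keras.Model, so _add_quant_wrapper rejects it during cloning. Again this goes through the public tfmot wrappers, which is an assumption rather than code from this commit.

    # Hedged, illustrative sketch of the nested-model error path.
    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    inner = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(2,))])
    outer = tf.keras.Sequential([inner])

    try:
      tfmot.quantization.keras.quantize_annotate_model(outer)
    except ValueError as e:
      # Expected message: 'Quantizing a tf.keras Model inside another
      # tf.keras Model is not supported.'
      print(e)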

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py

Lines changed: 11 additions & 0 deletions
@@ -52,6 +52,17 @@ def __init__(self, layer, quantize_config=None, **kwargs):
     """
     super(QuantizeAnnotate, self).__init__(layer, **kwargs)
 
+    if layer is None:
+      raise ValueError('`layer` cannot be None.')
+
+    # Check against keras.Model since it is an instance of keras.layers.Layer.
+    if not isinstance(layer, tf.keras.layers.Layer) or isinstance(
+        layer, tf.keras.Model):
+      raise ValueError(
+          '`layer` can only be a `tf.keras.layers.Layer` instance. '
+          'You passed an instance of type: {input}.'.format(
+              input=layer.__class__.__name__))
+
     self.quantize_config = quantize_config
 
     self._track_trackable(layer, name='layer')

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate_test.py

Lines changed: 6 additions & 0 deletions
@@ -89,6 +89,12 @@ def testSerializationQuantizeAnnotate(self):
 
     self.assertEqual(wrapper_from_config.get_config(), wrapper.get_config())
 
+  def testQuantizeAnnotate_FailsWithModel(self):
+    layer = keras.layers.Dense(5, activation='relu', input_shape=(10,))
+    model = keras.Sequential([layer])
+
+    with self.assertRaises(ValueError):
+      quantize_annotate.QuantizeAnnotate(model)
 
 if __name__ == '__main__':
   tf.test.main()

tensorflow_model_optimization/python/core/quantization/keras/quantize_test.py

Lines changed: 22 additions & 0 deletions
@@ -165,6 +165,20 @@ def testQuantizeAnnotateModel_RemovesOptimizer(self):
     annotated_model = quantize_annotate_model(model)
     self.assertIsNone(annotated_model.optimizer)
 
+  def testQuantizeAnnotateModel_FailsWithSubclassedModel(self):
+    class MyModel(keras.Model):
+      def call(self, inputs, training=None, mask=None):  # pylint: disable=g-wrong-blank-lines
+        return inputs
+
+    with self.assertRaises(ValueError):
+      quantize_annotate_model(MyModel())
+
+  def testQuantizeAnnotateModel_FailsWithNestedModels(self):
+    with self.assertRaises(ValueError):
+      quantize_annotate_model(
+          keras.Sequential(
+              [keras.Sequential([keras.layers.Dense(10, input_shape=(2,))])]))
+
 
 class QuantizeApplyTest(tf.test.TestCase):
 
@@ -411,6 +425,14 @@ def testQuantizeApply_RemovesOptimizer(self):
     quantized_model = quantize_apply(annotated_model)
     self.assertIsNone(quantized_model.optimizer)
 
+  def testQuantizeApply_RunsWhenNestedModelNotAnnotated(self):
+    annotated_model = keras.Sequential([
+        keras.Sequential([keras.layers.Dense(10, input_shape=(2,))]),
+        quantize_annotate_layer(keras.layers.Dense(10)),
+    ])
+
+    quantize_apply(annotated_model)
+
 
 if __name__ == '__main__':
   tf.test.main()

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py

Lines changed: 10 additions & 0 deletions
@@ -47,6 +47,16 @@ def __init__(self, layer, quantize_config, **kwargs):
       quantize_config: `QuantizeConfig` to quantize layer.
       **kwargs: Additional keyword arguments to be passed to the keras layer.
     """
+    if layer is None:
+      raise ValueError('`layer` cannot be None.')
+
+    # Check against keras.Model since it is an instance of keras.layers.Layer.
+    if not isinstance(layer, tf.keras.layers.Layer) or isinstance(
+        layer, tf.keras.Model):
+      raise ValueError(
+          '`layer` can only be a `tf.keras.layers.Layer` instance. '
+          'You passed an instance of type: {input}.'.format(
+              input=layer.__class__.__name__))
 
     if quantize_config is None:
       raise ValueError('quantize_config cannot be None. It is needed to '

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py

Lines changed: 9 additions & 0 deletions
@@ -180,6 +180,15 @@ def testSerializationQuantizeWrapper(self):
 
     self.assertEqual(wrapper_from_config.get_config(), wrapper.get_config())
 
+  def testQuantizeWrapper_FailsWithModel(self):
+    layer = keras.layers.Dense(5, activation='relu', input_shape=(10,))
+    model = keras.Sequential([layer])
+
+    with self.assertRaises(ValueError):
+      QuantizeWrapper(
+          model,
+          quantize_config=self.quantize_registry.get_quantize_config(layer))
+
   # TODO(pulkitb): Add test to ensure weights are also preserved.