@@ -288,12 +288,44 @@ def _assert_model_quantized(
         zip(annotated_model.layers, quantized_model.layers):
 
       if not isinstance(layer_annotated, QuantizeAnnotate):
-        self.assertNotIsInstance(layer_quantized, QuantizeWrapper)
+        # Possibly wrapped for input quantization.
+        if isinstance(layer_quantized, QuantizeWrapper):
+          self.assertLen(layer_annotated._outbound_nodes, 1)
+          self.assertIsInstance(
+              layer_annotated._outbound_nodes[0].outbound_layer,
+              QuantizeAnnotate)
+
+          # Ensure that only outputs are quantized.
+          self.assertFalse(
+              layer_quantized.quantize_config.get_weights_and_quantizers(
+                  layer_quantized.layer))
         continue
 
       self._assert_layer_quantized(
           layer_annotated, layer_quantized, exclude_keys)
 
+  def _assert_nonannotated_input_layer_quantized(
+      self, quantized_model, layer_index):
+    output_quantized_layer = quantized_model.layers[layer_index]
+    self.assertIsInstance(output_quantized_layer, QuantizeWrapper)
+    output_quantized_config = output_quantized_layer.quantize_config
+    default_quantized_config = output_quantized_config.get_config()[
+        'quantize_config']
+    self.assertIsInstance(
+        default_quantized_config,
+        default_8bit_quantize_registry.Default8BitQuantizeConfig)
+    self.assertFalse(output_quantized_config.get_weights_and_quantizers(
+        output_quantized_layer.layer))
+    self.assertTrue(default_quantized_config.get_weights_and_quantizers(
+        output_quantized_layer.layer))
+    output_quantizers = output_quantized_config.get_output_quantizers(
+        output_quantized_layer.layer)
+    default_output_quantizers = default_quantized_config.get_output_quantizers(
+        output_quantized_layer.layer)
+    for output_quantizer, default_quantizer in zip(output_quantizers,
+                                                   default_output_quantizers):
+      self.assertEqual(output_quantizer, default_quantizer)
+
   # quantize_apply Tests
 
   class CustomLayer(keras.layers.Dense):
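
The helper above expects the layer feeding an annotated layer to carry a config that quantizes its outputs but none of its weights. Below is a minimal, illustrative sketch of such an "output-only" config; it assumes the Default8BitQuantizeConfig(weight_attrs, activation_attrs, quantize_output) constructor from default_8bit_quantize_registry, is not the wrapper config this PR installs, and the import path may differ between tfmot versions.

# Illustrative only: an output-only quantize config, assuming the
# Default8BitQuantizeConfig constructor signature noted above.
import tensorflow as tf
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
    default_8bit_quantize_registry)

layer = tf.keras.layers.Dense(10)

# No weight attrs, no activation attrs, quantize the layer's outputs only.
output_only_config = default_8bit_quantize_registry.Default8BitQuantizeConfig(
    weight_attrs=[], activation_attrs=[], quantize_output=True)

print(output_only_config.get_weights_and_quantizers(layer))  # [] -- no weight quantizers
print(output_only_config.get_output_quantizers(layer))       # a single output quantizer
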
@@ -379,6 +411,18 @@ def testAppliesQuantizationToAnnotatedModel_Sequential(self):
 
     self._assert_model_quantized(model, quantized_model, ['activation'])
 
+  def testAppliesQuantizationToInputsToAnnotatedModel_Sequential(self):
+    model = keras.Sequential([
+        keras.layers.Conv2D(32, 5, input_shape=(28, 28, 1), activation='relu'),
+        keras.layers.Dense(10, activation='relu'),
+        quantize_annotate_layer(keras.layers.Dense(5, activation='softmax')),
+    ])
+    quantized_model = quantize_apply(model)
+    self._assert_model_quantized(model, quantized_model, ['activation'])
+    # Test that the Dense layer has an output-only quantization config.
+    self._assert_nonannotated_input_layer_quantized(
+        quantized_model, layer_index=1)
+
   def testAppliesQuantizationToAnnotatedModel_PreservesBuiltState(self):
     model = keras_test_utils.build_simple_dense_model()
     annotated_model = quantize_annotate_model(model)
@@ -404,6 +448,20 @@ def testAppliesQuantizationToAnnotatedModel_Functional(self):
 
     self._assert_model_quantized(model, quantized_model, ['activation'])
 
+  def testAppliesQuantizationToInputsToAnnotatedModel_Functional(self):
+    inputs = keras.Input(shape=(28, 28, 1))
+    x = keras.layers.Conv2D(32, 5, activation='relu')(inputs)
+    x = keras.layers.Dense(10, activation='relu')(x)
+    results = quantize_annotate_layer(
+        keras.layers.Dense(5, activation='softmax'))(
+            x)
+    model = keras.Model(inputs=inputs, outputs=results)
+    quantized_model = quantize_apply(model)
+    self._assert_model_quantized(model, quantized_model, ['activation'])
+    # Test that the Dense layer has an output-only quantization config.
+    self._assert_nonannotated_input_layer_quantized(
+        quantized_model, layer_index=2)
+
   def testDoesNotQuantizeInputLayer_OutboundLayerNotQuantized(self):
     model = self._get_simple_functional_model()
 
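
For context, the two new tests boil down to the flow sketched below, written against the public tfmot aliases of the quantize_annotate_layer/quantize_apply helpers this file imports; treat it as a hedged illustration rather than the exact test code.

# Sketch: annotate only the last layer, then apply quantization.
import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 5, input_shape=(28, 28, 1), activation='relu'),
    tf.keras.layers.Dense(10, activation='relu'),
    # Only the final layer is annotated for quantization.
    tfmot.quantization.keras.quantize_annotate_layer(
        tf.keras.layers.Dense(5, activation='softmax')),
])

quantized_model = tfmot.quantization.keras.quantize_apply(model)

# What the new tests assert: the Dense(10) layer that feeds the annotated
# layer is also wrapped, but with an output-only config (no weight
# quantizers), so the annotated layer receives quantized inputs.
print(type(quantized_model.layers[1]).__name__)  # expected: QuantizeWrapper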