@@ -288,12 +288,44 @@ def _assert_model_quantized(
         zip(annotated_model.layers, quantized_model.layers):
 
       if not isinstance(layer_annotated, QuantizeAnnotate):
-        self.assertNotIsInstance(layer_quantized, QuantizeWrapper)
+        # Possibly wrapped for input quantization.
+        if isinstance(layer_quantized, QuantizeWrapper):
+          self.assertLen(layer_annotated._outbound_nodes, 1)
+          self.assertIsInstance(
+              layer_annotated._outbound_nodes[0].outbound_layer,
+              QuantizeAnnotate)
+
+          # Ensure that only outputs are quantized.
+          self.assertFalse(
+              layer_quantized.quantize_config.get_weights_and_quantizers(
+                  layer_quantized.layer))
         continue
 
       self._assert_layer_quantized(
           layer_annotated, layer_quantized, exclude_keys)
 
+  def _assert_nonannotated_input_layer_quantized(
+      self, quantized_model, layer_index):
+    output_quantized_layer = quantized_model.layers[layer_index]
+    self.assertIsInstance(output_quantized_layer, QuantizeWrapper)
+    output_quantized_config = output_quantized_layer.quantize_config
+    default_quantized_config = output_quantized_config.get_config()[
+        'quantize_config']
+    self.assertIsInstance(
+        default_quantized_config,
+        default_8bit_quantize_registry.Default8BitQuantizeConfig)
+    self.assertFalse(output_quantized_config.get_weights_and_quantizers(
+        output_quantized_layer.layer))
+    self.assertTrue(default_quantized_config.get_weights_and_quantizers(
+        output_quantized_layer.layer))
+    output_quantizers = output_quantized_config.get_output_quantizers(
+        output_quantized_layer.layer)
+    default_output_quantizers = default_quantized_config.get_output_quantizers(
+        output_quantized_layer.layer)
+    for output_quantizer, default_quantizer in zip(output_quantizers,
+                                                   default_output_quantizers):
+      self.assertEqual(output_quantizer, default_quantizer)
+
   # quantize_apply Tests
 
   class CustomLayer(keras.layers.Dense):
@@ -379,6 +411,18 @@ def testAppliesQuantizationToAnnotatedModel_Sequential(self):
 
     self._assert_model_quantized(model, quantized_model, ['activation'])
 
+  def testAppliesQuantizationToInputsToAnnotatedModel_Sequential(self):
+    model = keras.Sequential([
+        keras.layers.Conv2D(32, 5, input_shape=(28, 28, 1), activation='relu'),
+        keras.layers.Dense(10, activation='relu'),
+        quantize_annotate_layer(keras.layers.Dense(5, activation='softmax')),
+    ])
+    quantized_model = quantize_apply(model)
+    self._assert_model_quantized(model, quantized_model, ['activation'])
+    # Test that Dense layer has output only quantization config.
+    self._assert_nonannotated_input_layer_quantized(
+        quantized_model, layer_index=1)
+
   def testAppliesQuantizationToAnnotatedModel_PreservesBuiltState(self):
     model = keras_test_utils.build_simple_dense_model()
     annotated_model = quantize_annotate_model(model)
@@ -404,6 +448,20 @@ def testAppliesQuantizationToAnnotatedModel_Functional(self):
 
     self._assert_model_quantized(model, quantized_model, ['activation'])
 
+  def testAppliesQuantizationToInputsToAnnotatedModel_Functional(self):
+    inputs = keras.Input(shape=(28, 28, 1))
+    x = keras.layers.Conv2D(32, 5, activation='relu')(inputs)
+    x = keras.layers.Dense(10, activation='relu')(x)
+    results = quantize_annotate_layer(
+        keras.layers.Dense(5, activation='softmax'))(
+            x)
+    model = keras.Model(inputs=inputs, outputs=results)
+    quantized_model = quantize_apply(model)
+    self._assert_model_quantized(model, quantized_model, ['activation'])
+    # Test that Dense layer has output only quantization config.
+    self._assert_nonannotated_input_layer_quantized(
+        quantized_model, layer_index=2)
+
   def testDoesNotQuantizeInputLayer_OutboundLayerNotQuantized(self):
     model = self._get_simple_functional_model()
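For reference, a minimal usage sketch of the workflow these new tests exercise, assuming the public `tfmot.quantization.keras` entry points (`quantize_annotate_layer`, `quantize_apply`); this is not part of the diff, and it simply mirrors the Sequential test case and the behavior the new assertions check (the un-annotated layer feeding the annotated one gets an output-only quantize config):

```python
# Sketch of the flow exercised by
# testAppliesQuantizationToInputsToAnnotatedModel_Sequential, assuming the
# public tfmot.quantization.keras API (not part of this diff).
import tensorflow as tf
import tensorflow_model_optimization as tfmot

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_apply = tfmot.quantization.keras.quantize_apply

# Only the last Dense layer is annotated for quantization.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 5, input_shape=(28, 28, 1), activation='relu'),
    tf.keras.layers.Dense(10, activation='relu'),
    quantize_annotate_layer(tf.keras.layers.Dense(5, activation='softmax')),
])

quantized_model = quantize_apply(model)

# Per the new assertions, the un-annotated Dense(10) layer that feeds the
# annotated layer is also wrapped, but with an output-only config (no weight
# quantizers), so the annotated layer receives quantized inputs.
quantized_model.summary()
```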