@@ -71,6 +71,7 @@ def quantize_scope(*args):
71
71
'QuantizeWrapperV2' : quantize_wrapper .QuantizeWrapperV2 ,
72
72
'QuantizeLayer' : quantize_layer .QuantizeLayer ,
73
73
'OutputOnlyConfig' : quantize_config_mod .OutputOnlyConfig ,
74
+ 'FixedQuantizeConfig' : quantize_config_mod .FixedQuantizeConfig ,
74
75
}
75
76
quantization_objects .update (default_8bit_quantize_registry ._types_dict ()) # pylint: disable=protected-access
76
77
quantization_objects .update (default_n_bit_quantize_registry ._types_dict ()) # pylint: disable=protected-access
@@ -472,3 +473,169 @@ def _quantize(layer): # pylint: disable=missing-docstring
472
473
473
474
return keras .models .clone_model (
474
475
transformed_model , input_tensors = None , clone_function = _quantize )
476
+
477
+
478
+ def _unwrap_first_input_name (inbound_nodes ):
479
+ """Unwrap inbound_nodes three times to get first input name.
480
+
481
+ Args:
482
+ inbound_nodes: A str config that indicates input node. This method assumed
483
+ the inbound_nodes looks like `[[['input', 0, 0, {}]]]`.
484
+
485
+ Returns:
486
+ Returns a str name for the first inbound node.
487
+ """
488
+ current = inbound_nodes
489
+
490
+ for _ in range (3 ):
491
+ if not current :
492
+ return None
493
+ if not isinstance (current , list ):
494
+ return None
495
+ current = current [0 ]
496
+
497
+ if isinstance (current , str ):
498
+ return current
499
+
500
+ return None
501
+
502
+
503
def _wrap_fixed_range(
    quantize_config, num_bits, init_min, init_max, narrow_range):
  """Builds a serialized `FixedQuantizeConfig` around `quantize_config`.

  Args:
    quantize_config: The quantize config to wrap, as stored under the
      'quantize_config' key of a layer config. It is forwarded as the
      'config' entry of the new `FixedQuantizeConfig`.
    num_bits: Forwarded as the 'num_bits' entry.
    init_min: Forwarded as the 'init_min' entry.
    init_max: Forwarded as the 'init_max' entry.
    narrow_range: Forwarded as the 'narrow_range' entry.

  Returns:
    The result of `tf.keras.utils.serialize_keras_object` applied to the
    `FixedQuantizeConfig` created from the arguments above.
  """
  fixed_config_args = {
      'config': quantize_config,
      'num_bits': num_bits,
      'init_min': init_min,
      'init_max': init_max,
      'narrow_range': narrow_range,
  }
  fixed_config = quantize_config_mod.FixedQuantizeConfig.from_config(
      fixed_config_args)
  return tf.keras.utils.serialize_keras_object(fixed_config)
512
+
513
+
514
+ def _is_serialized_node_data (nested ):
515
+ # Node data can be of form `[layer_name, node_id, tensor_id]` or
516
+ # `[layer_name, node_id, tensor_id, kwargs]`.
517
+ if (isinstance (nested , list ) and (len (nested ) in [3 , 4 ]) and
518
+ isinstance (nested [0 ], str )):
519
+ return True
520
+ return False
521
+
522
+
523
def _nested_to_flatten_node_data_list(nested):
  """Flattens nested node data into a flat list of node data entries.

  Args:
    nested: A single serialized node data entry (see
      `_is_serialized_node_data`), or an arbitrarily nested list/dict whose
      leaves are such entries.

  Returns:
    A list of the node data entries found in `nested`, in traversal order
    (list order, and dict value order for dicts).

  Raises:
    ValueError: If a value that is neither node data, a list, nor a dict is
      encountered.
  """
  if _is_serialized_node_data(nested):
    return [nested]

  if isinstance(nested, list):
    children = nested
  elif isinstance(nested, dict):
    children = nested.values()
  else:
    raise ValueError('{} is not a supported nested node data.'.format(nested))

  # Extend one accumulator instead of using `sum(..., [])`, which copies the
  # partial result on every addition and is quadratic in the output length.
  flattened = []
  for child in children:
    flattened.extend(_nested_to_flatten_node_data_list(child))
  return flattened
535
+
536
+
537
def fix_input_output_range(
    model,
    num_bits=8,
    input_min=0.0,
    input_max=1.0,
    output_min=0.0,
    output_max=1.0,
    narrow_range=False):
  """Fixes the input and output quantization ranges of a quantized model.

  In certain cases, the desired input/output ranges are known and should not
  be altered during training. This function edits the model config so that the
  input quantizer and the output quantize config use fixed ranges, and then
  rebuilds the model from that config.

  Example:

  ```python
  model = keras.Sequential([
      layers.Dense(10, activation='relu', input_shape=(100,)),
      quantize_annotate_layer(layers.Dense(2, activation='sigmoid'))
  ])
  with quantize.quantize_scope():
    model = quantize_annotate_model(model)
    model = quantize_apply(model)
    model = fix_input_output_range(model, num_bits=4,
                                   input_min=0, input_max=15,
                                   output_min=0, output_max=15,
                                   narrow_range=False)
  ```

  Args:
    model: A `tf.keras` Sequential or Functional model which has been
      quantized.
    num_bits: Number of bits for quantization.
    input_min: The lower end of quantization interval for the input.
    input_max: The upper end of quantization interval for the input.
    output_min: The lower end of quantization interval for the output.
    output_max: The upper end of quantization interval for the output.
    narrow_range: In case of 8 bits, narrow_range nudges the quantized range
      to be [-127, 127] instead of [-128, 127]. This ensures symmetric
      range has 0 as the centre.

  Returns:
    A new `tf.keras` model with the fixed input range set to (input_min,
    input_max) and the fixed output range set to (output_min, output_max).
    NOTE(review): the model is rebuilt with `from_config`, so weights are
    presumably not carried over from `model` - confirm against callers.

  Raises:
    ValueError: If `model` is not a functional model and its config has fewer
      than two layers or its second layer is not a `QuantizeLayer`.
  """
  config = model.get_config()
  serialized_fixed_input_quantizer = tf.keras.utils.serialize_keras_object(
      quantizers.FixedQuantizer(
          num_bits=num_bits,
          init_min=input_min,
          init_max=input_max,
          narrow_range=narrow_range))

  if _is_functional_model(model):
    # Layers whose first inbound node is a model input get the fixed
    # quantizer written into their config.
    input_names = {
        node_data[0] for node_data in
        _nested_to_flatten_node_data_list(config['input_layers'])}
    for layer_config in config['layers']:
      input_name = _unwrap_first_input_name(layer_config['inbound_nodes'])
      if input_name is None:
        continue
      if input_name in input_names:
        layer_config['config']['quantizer'] = serialized_fixed_input_quantizer

    # Output layers that carry a quantize config get it wrapped so the output
    # range is fixed; output layers without one are left untouched.
    output_names = {
        node_data[0] for node_data in
        _nested_to_flatten_node_data_list(config['output_layers'])}
    for layer_config in config['layers']:
      layer_params = layer_config['config']
      if (layer_params['name'] in output_names and
          'quantize_config' in layer_params):
        layer_params['quantize_config'] = _wrap_fixed_range(
            layer_params['quantize_config'],
            num_bits=num_bits,
            init_min=output_min,
            init_max=output_max,
            narrow_range=narrow_range)

    model = keras.Model.from_config(config)
  else:
    layer_configs = config['layers']
    # Index 1 is read below, so at least two entries are required; a shorter
    # config must report "not quantized" rather than raise IndexError.
    # Assumes index 0 is the input layer and index 1 the QuantizeLayer added
    # by quantization - TODO confirm for all supported Keras versions.
    if (len(layer_configs) < 2 or
        layer_configs[1]['class_name'] != 'QuantizeLayer'):
      raise ValueError('`model` should be already quantized.')
    layer_configs[1]['config']['quantizer'] = serialized_fixed_input_quantizer

    last_layer_params = layer_configs[-1]['config']
    if 'quantize_config' in last_layer_params:
      last_layer_params['quantize_config'] = _wrap_fixed_range(
          last_layer_params['quantize_config'],
          num_bits=num_bits,
          init_min=output_min,
          init_max=output_max,
          narrow_range=narrow_range)

    model = keras.Sequential.from_config(config)

  return model
636
+
637
+
638
def _is_functional_model(model):
  """Returns whether `model` is a graph-network Keras model, not Sequential.

  Args:
    model: The object to inspect.

  Returns:
    False if `model` is not a `keras.Model` or is a `keras.Sequential`;
    otherwise the value of `model._is_graph_network`.
  """
  if not isinstance(model, keras.Model):
    return False
  if isinstance(model, keras.Sequential):
    return False
  return model._is_graph_network  # pylint: disable=protected-access
0 commit comments