|
26 | 26 | from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
|
27 | 27 | from tensorflow_model_optimization.python.core.quantization.keras import quantizers
|
28 | 28 | from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_configs
|
| 29 | +from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry |
29 | 30 | from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_transforms
|
30 | 31 | from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer
|
31 | 32 | from tensorflow_model_optimization.python.core.quantization.keras.layers import conv_batchnorm_test_utils
|
@@ -576,6 +577,134 @@ def testConcatMultipleLevels(self):
|
576 | 577 | default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
|
577 | 578 | self.assertNotEmpty(quantize_config.get_output_quantizers(None))
|
578 | 579 |
|
def testConcatActivationTransform(self):
  r"""Tests the Concat Transform when Activation layers feed the Concat.

     Input     Input
       |         |
      Relu      Relu
        \       /
         Concat

  Both Relu layers are given a Default8BitActivationQuantizeConfig in the
  layer metadata handed to the transformer. The Transform should ensure both
  of their output FakeQuants are disabled, and only a FakeQuant after Concat
  is present.
  """
  relu_1 = keras.layers.Activation('relu')
  relu_2 = keras.layers.Activation('relu')
  concat = keras.layers.Concatenate()

  inp1 = keras.layers.Input((2,))
  inp2 = keras.layers.Input((2,))
  x1 = relu_1(inp1)
  x2 = relu_2(inp2)
  x = concat([x1, x2])
  model = keras.Model([inp1, inp2], x)

  # Each activation starts out with its own quantize_config instance.
  layer_metadata = {
      relu_1.name: {
          'quantize_config':
              (default_8bit_quantize_registry
               .Default8BitActivationQuantizeConfig())
      },
      relu_2.name: {
          'quantize_config':
              (default_8bit_quantize_registry
               .Default8BitActivationQuantizeConfig())
      },
  }
  _, updated_metadata = ModelTransformer(
      model, [default_8bit_transforms.ConcatTransform()],
      layer_metadata=layer_metadata).transform()

  concat_quantize_config = updated_metadata.get(
      concat.name).get('quantize_config')
  # Concat should quantize the output.
  self.assertIsInstance(
      concat_quantize_config,
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
  self.assertNotEmpty(concat_quantize_config.get_output_quantizers(None))

  # The config of every activation feeding the Concat should keep its type
  # but do nothing for outputs. Looping guarantees each layer is checked
  # against its own config rather than a sibling's.
  for relu in (relu_1, relu_2):
    relu_quantize_config = updated_metadata.get(
        relu.name).get('quantize_config')
    self.assertIsInstance(
        relu_quantize_config,
        default_8bit_quantize_registry.Default8BitActivationQuantizeConfig)
    self.assertEmpty(relu_quantize_config.get_output_quantizers(None))
    self.assertFalse(relu_quantize_config.quantize_output)
| 645 | + |
def testConcatConcatTransformDisablesOutput(self):
  r"""Tests the Concat Transform on a two-level tree of Concats.

     Input    Input    Input    Input
    Flatten  Flatten  Flatten  Flatten
        \      /          \      /
         Concat            Concat
              \            /
                  Concat

  The Transform should ensure all output FakeQuants are disabled, and only
  a FakeQuant after the last Concat is present. This test checks the four
  Flatten layers and the last Concat.
  """
  first_pair = [keras.layers.Flatten(), keras.layers.Flatten()]
  concat_1 = keras.layers.Concatenate()
  second_pair = [keras.layers.Flatten(), keras.layers.Flatten()]
  concat_2 = keras.layers.Concatenate()
  concat = keras.layers.Concatenate()
  flattens = first_pair + second_pair

  inputs = [keras.layers.Input((1, 2, 2)) for _ in range(4)]
  flat_outputs = [layer(inp) for layer, inp in zip(flattens, inputs)]

  left = concat_1(flat_outputs[:2])
  right = concat_2(flat_outputs[2:])
  merged = concat([left, right])
  model = keras.Model(inputs, merged)

  # Every Flatten gets its own config with output quantization enabled.
  layer_metadata = {
      layer.name: {
          'quantize_config':
              default_8bit_quantize_registry.Default8BitQuantizeConfig(
                  [], [], True)
      } for layer in flattens
  }
  transformer = ModelTransformer(
      model, [default_8bit_transforms.ConcatTransform()],
      layer_metadata=layer_metadata)
  _, updated_metadata = transformer.transform()

  # The last Concat should quantize the output.
  final_config = updated_metadata.get(concat.name).get('quantize_config')
  self.assertIsInstance(
      final_config,
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
  self.assertNotEmpty(final_config.get_output_quantizers(None))

  # The configs supplied for the Flatten layers should keep their type but
  # no longer quantize outputs.
  for layer in flattens:
    flatten_config = updated_metadata.get(layer.name).get('quantize_config')
    self.assertIsInstance(
        flatten_config,
        default_8bit_quantize_registry.Default8BitQuantizeConfig)
    self.assertEmpty(flatten_config.get_output_quantizers(layer))
    self.assertFalse(flatten_config.quantize_output)
| 707 | + |
579 | 708 |
|
# Script entry point: calls tf.test.main() when this module is run directly.
580 | 709 | if __name__ == '__main__':
|
581 | 710 | tf.test.main()
|
0 commit comments