27 | 27 | from tensorflow_model_optimization.python.core.quantization.keras import quantize_config as quantize_config_mod
28 | 28 | from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
29 | 29 | from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper as quantize_wrapper_mod
| 30 | +from tensorflow_model_optimization.python.core.quantization.keras import quantizers
30 | 31 | from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
31 | 32 |
32 | 33 | quantize_annotate_layer = quantize.quantize_annotate_layer
@@ -513,6 +514,69 @@ def testQuantizeApply_RunsWhenNestedModelNotAnnotated(self):
513 | 514 |
514 | 515 |     quantize_apply(annotated_model)
515 | 516 |
| 517 | +  class CustomConvLayer(tf.keras.layers.Layer):
| 518 | +
| 519 | +    def __init__(self, name=None, **kwargs):
| 520 | +      super().__init__(name=name, **kwargs)
| 521 | +      self.conv1 = tf.keras.layers.Conv2D(2, 2)
| 522 | +
| 523 | +    def build(self, input_shape):
| 524 | +      self.conv1.build(input_shape)
| 525 | +
| 526 | +    def call(self, inputs):
| 527 | +      return self.conv1(inputs)
| 528 | +
| 529 | +    def get_config(self):
| 530 | +      return {'name': self.name}
| 531 | +
| 532 | +  class CustomConvQuantizeConfig(quantize_config_mod.QuantizeConfig):
| 533 | +
| 534 | +    def get_weights_and_quantizers(self, layer):
| 535 | +      return [(layer.conv1.kernel, quantizers.LastValueQuantizer(
| 536 | +          num_bits=8, symmetric=True, narrow_range=False, per_axis=False)),]
| 537 | +
| 538 | +    def get_activations_and_quantizers(self, layer):
| 539 | +      return []
| 540 | +
| 541 | +    def set_quantize_weights(self, layer, quantize_weights):
| 542 | +      # layer.conv1._kernel_bak = layer.conv1.kernel
| 543 | +      layer.conv1.kernel = quantize_weights[0]
| 544 | +
| 545 | +    def set_quantize_activations(self, layer, quantize_activations):
| 546 | +      pass
| 547 | +
| 548 | +    def get_output_quantizers(self, layer):
| 549 | +      return []
| 550 | +
| 551 | +    def get_config(self):
| 552 | +      return {}
| 553 | +
| 554 | +  def testQuantizeApply_KeepTrainableWeightOrder(self):
| 555 | +    layer = self.CustomConvLayer(input_shape=(28, 28, 3))
| 556 | +    model = keras.Sequential([layer])
| 557 | +
| 558 | +    def apply_quantization_to_dense(layer):
| 559 | +      if isinstance(layer, self.CustomConvLayer):
| 560 | +        return quantize_annotate_layer(
| 561 | +            layer, quantize_config=self.CustomConvQuantizeConfig())
| 562 | +      return layer
| 563 | +
| 564 | +    annotated_model = tf.keras.models.clone_model(
| 565 | +        model,
| 566 | +        clone_function=apply_quantization_to_dense,
| 567 | +    )
| 568 | +
| 569 | +    with quantize.quantize_scope({
| 570 | +        'CustomConvQuantizeConfig': self.CustomConvQuantizeConfig,
| 571 | +        'CustomConvLayer': self.CustomConvLayer
| 572 | +    }):
| 573 | +      quant_aware_model = quantize_apply(annotated_model)
| 574 | +
| 575 | +    self._assert_weights_different_objects(
| 576 | +        model.trainable_weights, quant_aware_model.trainable_weights)
| 577 | +    self._assert_weights_equal_value(
| 578 | +        model.trainable_weights, quant_aware_model.trainable_weights)
| 579 | +
516 | 580 |
517 | 581 | if __name__ == '__main__':
518 | 582 |   tf.test.main()