|
102 | 102 | "outputs": [],
|
103 | 103 | "source": [
|
104 | 104 | "! pip uninstall -y tensorflow\n",
|
105 |
| - "! pip install -q tf-nightly==2.2.0.dev20200315\n", |
106 |
| - "! pip install -q --extra-index-url=https://test.pypi.org/simple/ tensorflow-model-optimization==0.3.0.dev6\n", |
| 105 | + "! pip install -q tf-nightly\n", |
| 106 | + "! pip install -q tensorflow-model-optimization\n", |
107 | 107 | "\n",
|
108 | 108 | "import tensorflow as tf\n",
|
109 | 109 | "import numpy as np\n",
|
|
610 | 610 | "id": "YmyhI_bzWb2w"
|
611 | 611 | },
|
612 | 612 | "source": [
|
613 |
| - "This example uses the `DefaultDenseQuantizeConfig` to quantize a `Dense` layer. In practice, the layer\n", |
614 |
| - "can be any custom Keras layer.\n", |
| 613 | + "This example uses the `DefaultDenseQuantizeConfig` to quantize the `CustomLayer`.\n", |
615 | 614 | "\n",
|
616 | 615 | "Applying the configuration is the same across\n",
|
617 | 616 | "the \"Experiment with quantization\" use cases.\n",
|
618 |
| - " * Apply `tfmot.quantization.keras.quantize_annotate_layer` to the `Dense` layer and pass in the `QuantizeConfig`.\n", |
| 617 | + " * Apply `tfmot.quantization.keras.quantize_annotate_layer` to the `CustomLayer` and pass in the `QuantizeConfig`.\n", |
619 | 618 | " * Use\n",
|
620 | 619 | "`tfmot.quantization.keras.quantize_annotate_model` to continue to quantize the rest of the model with the API defaults.\n",
|
621 | 620 | "\n"
|
|
635 | 634 | "quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model\n",
|
636 | 635 | "quantize_scope = tfmot.quantization.keras.quantize_scope\n",
|
637 | 636 | "\n",
|
| 637 | + "class CustomLayer(tf.keras.layers.Dense):\n", |
| 638 | + " pass\n", |
| 639 | + "\n", |
638 | 640 | "model = quantize_annotate_model(tf.keras.Sequential([\n",
|
639 |
| - " quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),\n", |
| 641 | + " quantize_annotate_layer(CustomLayer(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),\n", |
640 | 642 | " tf.keras.layers.Flatten()\n",
|
641 | 643 | "]))\n",
|
642 | 644 | "\n",
|
643 |
| - "# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`:\n", |
| 645 | + "# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`\n", |
| 646 | + "# as well as the custom Keras layer.\n", |
644 | 647 | "with quantize_scope(\n",
|
645 |
| - " {'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig}):\n", |
| 648 | + " {'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig,\n", |
| 649 | + " 'CustomLayer': CustomLayer}):\n", |
646 | 650 | " # Use `quantize_apply` to actually make the model quantization aware.\n",
|
647 | 651 | " quant_aware_model = tfmot.quantization.keras.quantize_apply(model)\n",
|
648 | 652 | "\n",
|
|
864 | 868 | " # Not needed. No new TensorFlow variables needed.\n",
|
865 | 869 | " return {}\n",
|
866 | 870 | "\n",
|
867 |
| - " def __call__(self, inputs, step, training, **kwargs):\n", |
| 871 | + " def __call__(self, inputs, training, weights, **kwargs):\n", |
868 | 872 | " return tf.keras.backend.clip(inputs, -1.0, 1.0)\n",
|
869 | 873 | "\n",
|
870 | 874 | " def get_config(self):\n",
|
|
0 commit comments