diff --git a/docs/save_load.rst b/docs/save_load.rst
index e225c5d0..e90e4eba 100644
--- a/docs/save_load.rst
+++ b/docs/save_load.rst
@@ -59,6 +59,57 @@ For example:
     # Or saved and pushed to the Hub simultaneously
     model.save_pretrained('username/my-model', push_to_hub=True, metrics={'accuracy': 0.95}, dataset='my_dataset')
 
+Saving with preprocessing transform (Albumentations)
+----------------------------------------------------
+
+You can save the preprocessing transform along with the model and push both to the Hub.
+This is useful when you want to share the model together with the preprocessing transform that was used during training,
+so that the inference pipeline stays consistent with the training pipeline.
+
+.. code:: python
+
+    import albumentations as A
+    import segmentation_models_pytorch as smp
+
+    # Define a preprocessing transform for images that will be used during inference
+    preprocessing_transform = A.Compose([
+        A.Resize(256, 256),
+        A.Normalize()
+    ])
+
+    model = smp.Unet()
+
+    directory_or_repo_on_the_hub = "qubvel-hf/unet-with-transform"  # <username>/<repo-name>
+
+    # Save the model and transform (and push to the Hub, if needed)
+    model.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)
+    preprocessing_transform.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)
+
+    # Load the model and transform back
+    restored_model = smp.from_pretrained(directory_or_repo_on_the_hub)
+    restored_transform = A.Compose.from_pretrained(directory_or_repo_on_the_hub)
+
+    print(restored_transform)
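+
+At inference time, apply the restored transform to each image before passing the result to the model.
+The following is only a minimal sketch: the random ``image`` array stands in for a real HWC ``uint8`` input image.
+
+.. code:: python
+
+    import numpy as np
+    import torch
+
+    image = np.random.randint(0, 255, size=(512, 512, 3), dtype=np.uint8)  # placeholder input
+
+    # Apply the same preprocessing that was saved with the model
+    preprocessed = restored_transform(image=image)["image"]
+
+    # HWC numpy array -> NCHW float tensor
+    input_tensor = torch.as_tensor(preprocessed).permute(2, 0, 1).unsqueeze(0)
+
+    restored_model.eval()
+    with torch.no_grad():
+        mask_logits = restored_model(input_tensor)  # (1, classes, 256, 256)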
 
 Conclusion
 ----------
@@ -71,4 +122,6 @@ By following these steps, you can easily save, share, and load your models, faci
     :target: https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/binary_segmentation_intro.ipynb
     :alt: Open In Colab
 
-
+.. |colab-badge| image:: https://colab.research.google.com/assets/colab-badge.svg
+    :target: https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb
+    :alt: Open In Colab
diff --git a/examples/save_load_model_and_share_with_hf_hub.ipynb b/examples/save_load_model_and_share_with_hf_hub.ipynb
index d6d4c0f4..d27a5a0a 100644
--- a/examples/save_load_model_and_share_with_hf_hub.ipynb
+++ b/examples/save_load_model_and_share_with_hf_hub.ipynb
@@ -48,7 +48,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -70,7 +70,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
    {
@@ -82,9 +82,11 @@
     "license: mit\n",
     "pipeline_tag: image-segmentation\n",
     "tags:\n",
+    "- model_hub_mixin\n",
+    "- pytorch_model_hub_mixin\n",
+    "- segmentation-models-pytorch\n",
     "- semantic-segmentation\n",
     "- pytorch\n",
-    "- segmentation-models-pytorch\n",
     "languages:\n",
     "- python\n",
     "---\n",
@@ -157,7 +159,7 @@ {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
-      "model_id": "075ae026811542bdb4030e53b943efc7",
+      "model_id": "1d6fe9d868c24175aa5f23a2893a2c21",
       "version_major": 2,
       "version_minor": 0
      },
     "text/plain": [
@@ -179,13 +181,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
-      "model_id": "2921a81d7fd747939b4a425cc17d6104",
+      "model_id": "2f4f5e4973e44f9a857e89d9ac707b53",
       "version_major": 2,
       "version_minor": 0
      },
@@ -199,10 +201,10 @@
    {
     "data": {
      "text/plain": [
-      "CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/9f821c7bc3a12db827c0da96a31f354ec6ba5253', commit_message='Push model using huggingface_hub.', commit_description='', oid='9f821c7bc3a12db827c0da96a31f354ec6ba5253', pr_url=None, pr_revision=None, pr_num=None)"
+      "CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/4ac3d2925d34cf183dc79a2e21b6e2f4bfe87586', commit_message='Push model using huggingface_hub.', commit_description='', oid='4ac3d2925d34cf183dc79a2e21b6e2f4bfe87586', pr_url=None, pr_revision=None, pr_num=None)"
     ]
    },
-   "execution_count": 8,
+   "execution_count": 6,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -224,6 +226,158 @@
    "\n",
    "# see result here https://huggingface.co/qubvel-hf/unet-with-metadata"
   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Save model with preprocessing (using albumentations)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install -U albumentations numpy==1.*"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import albumentations as A\n",
+    "import segmentation_models_pytorch as smp"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define a preprocessing transform for images that will be used during inference\n",
+    "preprocessing_transform = A.Compose([\n",
+    "    A.Resize(256, 256),\n",
+    "    A.Normalize()\n",
+    "])\n",
+    "\n",
+    "model = smp.Unet()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1aa3f4db4cd2489baeac3b844977d5a2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+ "model.safetensors: 0%| | 0.00/97.8M [00:00