35 changes: 34 additions & 1 deletion docs/save_load.rst
@@ -59,6 +59,37 @@ For example:
# Or saved and pushed to the Hub simultaneously
model.save_pretrained('username/my-model', push_to_hub=True, metrics={'accuracy': 0.95}, dataset='my_dataset')

Saving with preprocessing transform (Albumentations)
----------------------------------------------------

You can save the preprocessing transform along with the model and push both to the Hub.
This is useful when you want to share the model together with the preprocessing transform used during training,
ensuring that the inference pipeline stays consistent with the training pipeline.

.. code:: python

import albumentations as A
import segmentation_models_pytorch as smp

# Define the image preprocessing transform to be used during inference
preprocessing_transform = A.Compose([
A.Resize(256, 256),
A.Normalize()
])

model = smp.Unet()

directory_or_repo_on_the_hub = "qubvel-hf/unet-with-transform" # <username>/<repo-name>

# Save the model and transform (and push to the Hub, if needed)
model.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)
preprocessing_transform.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)

# Load the model and transform back
restored_model = smp.from_pretrained(directory_or_repo_on_the_hub)
restored_transform = A.Compose.from_pretrained(directory_or_repo_on_the_hub)

print(restored_transform)
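
As a minimal inference sketch (with a random placeholder image standing in for real data),
the restored transform and model can then be applied together:

.. code:: python

    import numpy as np
    import torch

    # A dummy RGB image; in practice, load your own (H, W, 3) uint8 array
    image = np.random.randint(0, 255, size=(512, 384, 3), dtype=np.uint8)

    # Apply the restored preprocessing, then convert HWC -> CHW and add a batch dimension
    preprocessed = restored_transform(image=image)["image"]
    batch = torch.from_numpy(preprocessed).permute(2, 0, 1).unsqueeze(0)

    restored_model.eval()
    with torch.no_grad():
        mask = restored_model(batch)  # logits of shape (1, classes, 256, 256)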

Conclusion
----------
@@ -71,4 +102,6 @@ By following these steps, you can easily save, share, and load your models, faci
:target: https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/binary_segmentation_intro.ipynb
:alt: Open In Colab


.. |colab-badge| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb
:alt: Open In Colab
170 changes: 162 additions & 8 deletions examples/save_load_model_and_share_with_hf_hub.ipynb
@@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -82,9 +82,11 @@
"license: mit\n",
"pipeline_tag: image-segmentation\n",
"tags:\n",
"- model_hub_mixin\n",
"- pytorch_model_hub_mixin\n",
"- segmentation-models-pytorch\n",
"- semantic-segmentation\n",
"- pytorch\n",
"- segmentation-models-pytorch\n",
"languages:\n",
"- python\n",
"---\n",
@@ -157,7 +159,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "075ae026811542bdb4030e53b943efc7",
"model_id": "1d6fe9d868c24175aa5f23a2893a2c21",
"version_major": 2,
"version_minor": 0
},
@@ -179,13 +181,13 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2921a81d7fd747939b4a425cc17d6104",
"model_id": "2f4f5e4973e44f9a857e89d9ac707b53",
"version_major": 2,
"version_minor": 0
},
@@ -199,10 +201,10 @@
{
"data": {
"text/plain": [
"CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/9f821c7bc3a12db827c0da96a31f354ec6ba5253', commit_message='Push model using huggingface_hub.', commit_description='', oid='9f821c7bc3a12db827c0da96a31f354ec6ba5253', pr_url=None, pr_revision=None, pr_num=None)"
"CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/4ac3d2925d34cf183dc79a2e21b6e2f4bfe87586', commit_message='Push model using huggingface_hub.', commit_description='', oid='4ac3d2925d34cf183dc79a2e21b6e2f4bfe87586', pr_url=None, pr_revision=None, pr_num=None)"
]
},
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -224,6 +226,158 @@
"\n",
"# see result here https://huggingface.co/qubvel-hf/unet-with-metadata"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save model with preprocessing (using albumentations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install -U albumentations numpy==1.*"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import albumentations as A\n",
"import segmentation_models_pytorch as smp"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# define a preprocessing transform for image that would be used during inference\n",
"preprocessing_transform = A.Compose([\n",
" A.Resize(256, 256),\n",
" A.Normalize()\n",
"])\n",
"\n",
"model = smp.Unet()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1aa3f4db4cd2489baeac3b844977d5a2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/97.8M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-transform/commit/680dad16431fa6efbb25832d33a24056bdf7dc1a', commit_message='Push transform using huggingface_hub.', commit_description='', oid='680dad16431fa6efbb25832d33a24056bdf7dc1a', pr_url=None, pr_revision=None, pr_num=None)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"directory_or_repo_on_the_hub = \"qubvel-hf/unet-with-transform\"\n",
"\n",
"# save the model\n",
"model.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)\n",
"\n",
"# save transform\n",
"preprocessing_transform.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's restore model and preprocessing transform for inference:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading weights from local directory\n",
"Compose([\n",
" Resize(p=1.0, height=256, width=256, interpolation=1),\n",
" Normalize(p=1.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, normalization='standard'),\n",
"], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)\n"
]
}
],
"source": [
"restored_model = smp.from_pretrained(directory_or_repo_on_the_hub)\n",
"restored_transform = A.Compose.from_pretrained(directory_or_repo_on_the_hub)\n",
"\n",
"print(restored_transform)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Compose([\n",
" HorizontalFlip(p=0.5),\n",
" RandomBrightnessContrast(p=0.2, brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), brightness_by_max=True),\n",
" ShiftScaleRotate(p=0.5, shift_limit_x=(-0.0625, 0.0625), shift_limit_y=(-0.0625, 0.0625), scale_limit=(-0.09999999999999998, 0.10000000000000009), rotate_limit=(-45, 45), interpolation=1, border_mode=4, value=0.0, mask_value=0.0, rotate_method='largest_box'),\n",
"], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)\n"
]
}
],
"source": [
"# You can also save training augmentations to the Hub too (and load it back)!\n",
"#! Just make sure to provide key=\"train\" when saving and loading the augmentations.\n",
"\n",
"train_augmentations = A.Compose([\n",
" A.HorizontalFlip(p=0.5),\n",
" A.RandomBrightnessContrast(p=0.2),\n",
" A.ShiftScaleRotate(p=0.5),\n",
"])\n",
"\n",
"train_augmentations.save_pretrained(directory_or_repo_on_the_hub, key=\"train\", push_to_hub=True)\n",
"\n",
"restored_train_augmentations = A.Compose.from_pretrained(directory_or_repo_on_the_hub, key=\"train\")\n",
"print(restored_train_augmentations)"
]
},
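{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal inference sketch (with a random placeholder image standing in for real data), the restored transform and model can be applied together:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"# a dummy RGB image; in practice, load your own (H, W, 3) uint8 array\n",
"image = np.random.randint(0, 255, size=(512, 384, 3), dtype=np.uint8)\n",
"\n",
"# apply the restored preprocessing, then convert HWC -> CHW and add a batch dimension\n",
"preprocessed = restored_transform(image=image)[\"image\"]\n",
"batch = torch.from_numpy(preprocessed).permute(2, 0, 1).unsqueeze(0)\n",
"\n",
"restored_model.eval()\n",
"with torch.no_grad():\n",
"    mask = restored_model(batch)  # logits of shape (1, classes, 256, 256)"
]
},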
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See saved model and `albumentations` configs on the hub: https://huggingface.co/qubvel-hf/unet-with-transform/tree/main"
]
}
],
"metadata": {