diff --git a/examples/save_load_model_and_share_with_hf_hub.ipynb b/examples/save_load_model_and_share_with_hf_hub.ipynb new file mode 100644 index 00000000..d6d4c0f4 --- /dev/null +++ b/examples/save_load_model_and_share_with_hf_hub.ipynb @@ -0,0 +1,250 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import segmentation_models_pytorch as smp" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save to local directory and load back" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading weights from local directory\n" + ] + } + ], + "source": [ + "model = smp.Unet()\n", + "\n", + "# save the model\n", + "model.save_pretrained(\"saved-model-dir/unet/\")\n", + "\n", + "# load the model\n", + "restored_model = smp.from_pretrained(\"saved-model-dir/unet/\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save model with additional metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "model = smp.Unet()\n", + "\n", + "# save the model\n", + "model.save_pretrained(\n", + " \"saved-model-dir/unet-with-metadata/\",\n", + "\n", + " # additional information to be saved with the model\n", + " # only \"dataset\" and \"metrics\" are supported\n", + " dataset=\"PASCAL VOC\", # only string name is supported\n", + " metrics={ # should be a dictionary with metric name as key and metric value as value\n", + " \"mIoU\": 0.95,\n", + " \"accuracy\": 0.96\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---\n", + "library_name: segmentation-models-pytorch\n", + "license: mit\n", + "pipeline_tag: image-segmentation\n", + "tags:\n", + "- semantic-segmentation\n", + "- pytorch\n", + "- segmentation-models-pytorch\n", + "languages:\n", + "- python\n", + "---\n", + "# Unet Model Card\n", + "\n", + "Table of Contents:\n", + "- [Load trained model](#load-trained-model)\n", + "- [Model init parameters](#model-init-parameters)\n", + "- [Model metrics](#model-metrics)\n", + "- [Dataset](#dataset)\n", + "\n", + "## Load trained model\n", + "```python\n", + "import segmentation_models_pytorch as smp\n", + "\n", + "model = smp.from_pretrained(\"\")\n", + "```\n", + "\n", + "## Model init parameters\n", + "```python\n", + "model_init_params = {\n", + " \"encoder_name\": \"resnet34\",\n", + " \"encoder_depth\": 5,\n", + " \"encoder_weights\": \"imagenet\",\n", + " \"decoder_use_batchnorm\": True,\n", + " \"decoder_channels\": (256, 128, 64, 32, 16),\n", + " \"decoder_attention_type\": None,\n", + " \"in_channels\": 3,\n", + " \"classes\": 1,\n", + " \"activation\": None,\n", + " \"aux_params\": None\n", + "}\n", + "```\n", + "\n", + "## Model metrics\n", + "```json\n", + "{\n", + " \"mIoU\": 0.95,\n", + " \"accuracy\": 0.96\n", + "}\n", + "```\n", + "\n", + "## Dataset\n", + "Dataset name: PASCAL VOC\n", + "\n", + "## More Information\n", + "- Library: https://github.com/qubvel/segmentation_models.pytorch\n", + "- Docs: https://smp.readthedocs.io/en/latest/\n", + "\n", + "This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin)" + ] + } + ], + "source": [ + "!cat \"saved-model-dir/unet-with-metadata/README.md\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Share model with HF Hub" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "075ae026811542bdb4030e53b943efc7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
=0.5.0 pretrainedmodels==0.7.4 efficientnet-pytorch==0.7.1 timm==0.9.7 +huggingface_hub>=0.24.6 tqdm pillow diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py index 6a92457c..d3778ecc 100644 --- a/segmentation_models_pytorch/__init__.py +++ b/segmentation_models_pytorch/__init__.py @@ -1,3 +1,5 @@ +import warnings + from . import datasets from . import encoders from . import decoders @@ -20,6 +22,9 @@ from typing import Optional as _Optional import torch as _torch +# Suppress the specific SyntaxWarning for `pretrainedmodels` +warnings.filterwarnings("ignore", message="is with a literal", category=SyntaxWarning) + def create_model( arch: str, diff --git a/segmentation_models_pytorch/base/hub_mixin.py b/segmentation_models_pytorch/base/hub_mixin.py index 0e642d2c..8095c5b8 100644 --- a/segmentation_models_pytorch/base/hub_mixin.py +++ b/segmentation_models_pytorch/base/hub_mixin.py @@ -26,7 +26,7 @@ ```python import segmentation_models_pytorch as smp -model = smp.{{ model_name }}.from_pretrained("{{ save_directory | default("", true)}}") +model = smp.from_pretrained("") ``` ## Model init parameters @@ -61,23 +61,22 @@ def _format_parameters(parameters: dict): class SMPHubMixin(PyTorchModelHubMixin): def generate_model_card(self, *args, **kwargs) -> ModelCard: - model_parameters_json = _format_parameters(self._hub_mixin_config) - directory = self._save_directory if hasattr(self, "_save_directory") else None - repo_id = self._repo_id if hasattr(self, "_repo_id") else None - repo_or_directory = repo_id if repo_id is not None else directory - - metrics = self._metrics if hasattr(self, "_metrics") else None - dataset = self._dataset if hasattr(self, "_dataset") else None + model_parameters_json = _format_parameters(self.config) + metrics = kwargs.get("metrics", None) + dataset = kwargs.get("dataset", None) if metrics is not None: metrics = json.dumps(metrics, indent=4) metrics = f"```json\n{metrics}\n```" + tags = self._hub_mixin_info.model_card_data.get("tags", []) or [] + tags.extend(["segmentation-models-pytorch", "semantic-segmentation", "pytorch"]) + model_card_data = ModelCardData( languages=["python"], library_name="segmentation-models-pytorch", license="mit", - tags=["semantic-segmentation", "pytorch", "segmentation-models-pytorch"], + tags=tags, pipeline_tag="image-segmentation", ) model_card = ModelCard.from_template( @@ -86,64 +85,49 @@ def generate_model_card(self, *args, **kwargs) -> ModelCard: repo_url="https://github.com/qubvel/segmentation_models.pytorch", docs_url="https://smp.readthedocs.io/en/latest/", model_parameters=model_parameters_json, - save_directory=repo_or_directory, model_name=self.__class__.__name__, metrics=metrics, dataset=dataset, ) return model_card - def _set_attrs_from_kwargs(self, attrs, kwargs): - for attr in attrs: - if attr in kwargs: - setattr(self, f"_{attr}", kwargs.pop(attr)) - - def _del_attrs(self, attrs): - for attr in attrs: - if hasattr(self, f"_{attr}"): - delattr(self, f"_{attr}") - @wraps(PyTorchModelHubMixin.save_pretrained) def save_pretrained( self, save_directory: Union[str, Path], *args, **kwargs ) -> Optional[str]: - # set additional attributes to be used in generate_model_card - self._save_directory = save_directory - self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs) + model_card_kwargs = kwargs.pop("model_card_kwargs", {}) + if "dataset" in kwargs: + model_card_kwargs["dataset"] = kwargs.pop("dataset") + if "metrics" in kwargs: + model_card_kwargs["metrics"] = kwargs.pop("metrics") + kwargs["model_card_kwargs"] = model_card_kwargs - # set additional attribute to be used in from_pretrained - self._hub_mixin_config["_model_class"] = self.__class__.__name__ + # set additional attribute to be able to deserialize the model + self.config["_model_class"] = self.__class__.__name__ try: # call the original save_pretrained result = super().save_pretrained(save_directory, *args, **kwargs) finally: - # delete the additional attributes - self._del_attrs(["save_directory", "metrics", "dataset"]) - self._hub_mixin_config.pop("_model_class", None) + self.config.pop("_model_class", None) return result - @wraps(PyTorchModelHubMixin.push_to_hub) - def push_to_hub(self, repo_id: str, *args, **kwargs): - self._repo_id = repo_id - self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs) - result = super().push_to_hub(repo_id, *args, **kwargs) - self._del_attrs(["repo_id", "metrics", "dataset"]) - return result - @property - def config(self): + def config(self) -> dict: return self._hub_mixin_config @wraps(PyTorchModelHubMixin.from_pretrained) def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs): - config_path = hf_hub_download( - pretrained_model_name_or_path, - filename="config.json", - revision=kwargs.get("revision", None), - ) + config_path = Path(pretrained_model_name_or_path) / "config.json" + if not config_path.exists(): + config_path = hf_hub_download( + pretrained_model_name_or_path, + filename="config.json", + revision=kwargs.get("revision", None), + ) + with open(config_path, "r") as f: config = json.load(f) model_class_name = config.pop("_model_class")