Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/save_load.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ For example:
# Alternatively, load the model directly from the Hugging Face Hub
model = smp.from_pretrained('username/my-model')

Loading a pre-trained model with a different number of classes for fine-tuning:

.. code:: python

import segmentation_models_pytorch as smp

model = smp.from_pretrained('<path-or-repo-name>', classes=5, strict=False)

Saving model Metrics and Dataset Name
-------------------------------------

Expand Down
6 changes: 3 additions & 3 deletions examples/segformer_inference_pretrained.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
"metadata": {},
"outputs": [],
"source": [
"# fix for HF hub download\n",
"# see PR https://github.com/albumentations-team/albumentations/pull/2171\n",
"!pip install -U git+https://github.com/qubvel/albumentations@patch-2"
"# make sure you have the latest version of the libraries\n",
"!pip install -U segmentation-models-pytorch\n",
"!pip install albumentations matplotlib requests pillow"
]
},
{
Expand Down
61 changes: 42 additions & 19 deletions segmentation_models_pytorch/base/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from typing import TypeVar, Type
import warnings

from typing import TypeVar, Type
from . import initialization as init
from .hub_mixin import SMPHubMixin
from .utils import is_torch_compiling
Expand Down Expand Up @@ -96,23 +97,45 @@
# timm- ported encoders with TimmUniversalEncoder
from segmentation_models_pytorch.encoders import TimmUniversalEncoder

if not isinstance(self.encoder, TimmUniversalEncoder):
return super().load_state_dict(state_dict, **kwargs)

patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]

is_deprecated_encoder = any(
self.encoder.name.startswith(pattern) for pattern in patterns
)

if is_deprecated_encoder:
keys = list(state_dict.keys())
for key in keys:
new_key = key
if key.startswith("encoder.") and not key.startswith("encoder.model."):
new_key = "encoder.model." + key.removeprefix("encoder.")
if "gernet" in self.encoder.name:
new_key = new_key.replace(".stages.", ".stages_")
state_dict[new_key] = state_dict.pop(key)
if isinstance(self.encoder, TimmUniversalEncoder):
patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]
is_deprecated_encoder = any(
self.encoder.name.startswith(pattern) for pattern in patterns
)
if is_deprecated_encoder:
keys = list(state_dict.keys())
for key in keys:
new_key = key
if key.startswith("encoder.") and not key.startswith(

Check warning on line 109 in segmentation_models_pytorch/base/model.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/base/model.py#L106-L109

Added lines #L106 - L109 were not covered by tests
"encoder.model."
):
new_key = "encoder.model." + key.removeprefix("encoder.")
if "gernet" in self.encoder.name:
new_key = new_key.replace(".stages.", ".stages_")
state_dict[new_key] = state_dict.pop(key)

Check warning on line 115 in segmentation_models_pytorch/base/model.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/base/model.py#L112-L115

Added lines #L112 - L115 were not covered by tests

# To be able to load weight with mismatched sizes
# We are going to filter mismatched sizes as well if strict=False
strict = kwargs.get("strict", True)
if not strict:
mismatched_keys = []
model_state_dict = self.state_dict()
common_keys = set(model_state_dict.keys()) & set(state_dict.keys())
for key in common_keys:
if model_state_dict[key].shape != state_dict[key].shape:
mismatched_keys.append(
(key, model_state_dict[key].shape, state_dict[key].shape)
)
state_dict.pop(key)

if mismatched_keys:
str_keys = "\n".join(
[
f" - {key}: {s} (weights) -> {m} (model)"
for key, m, s in mismatched_keys
]
)
text = f"\n\n !!!!!! Mismatched keys !!!!!!\n\nYou should TRAIN the model to use it:\n{str_keys}\n"
warnings.warn(text, stacklevel=-1)

return super().load_state_dict(state_dict, **kwargs)
36 changes: 36 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
import tempfile
import segmentation_models_pytorch as smp

import pytest


def test_from_pretrained_with_mismatched_keys():
    """Check ``from_pretrained(..., strict=False)`` with a resized head.

    A single-class Unet is saved and then restored with ``classes=2``.
    The restore must emit the "Mismatched keys" warning naming the
    segmentation-head tensors, build a two-class head, and leave every
    other tensor equal to the saved one.
    """
    original_model = smp.Unet(classes=1)

    with tempfile.TemporaryDirectory() as temp_dir:
        original_model.save_pretrained(temp_dir)

        # Match on the message so an unrelated UserWarning cannot
        # satisfy the check by accident.
        with pytest.warns(UserWarning, match="Mismatched keys") as caught:
            restored_model = smp.from_pretrained(temp_dir, classes=2, strict=False)

    assert restored_model.segmentation_head[0].out_channels == 2

    expected_mismatched_keys = {
        "segmentation_head.0.weight",
        "segmentation_head.0.bias",
    }

    # The warning text should list every tensor that was skipped.
    warning_text = "\n".join(str(w.message) for w in caught)
    for key in expected_mismatched_keys:
        assert key in warning_text

    # Verify all the weights are the same except the mismatched ones.
    # Mismatches are detected from the actual tensor shapes instead of
    # being copied from the expected list, so the final comparison is
    # a real check rather than a tautology.
    original_state_dict = original_model.state_dict()
    restored_state_dict = restored_model.state_dict()

    mismatched_keys = set()
    for key, original_tensor in original_state_dict.items():
        restored_tensor = restored_state_dict[key]
        if original_tensor.shape != restored_tensor.shape:
            mismatched_keys.add(key)
        else:
            assert torch.allclose(original_tensor, restored_tensor), key

    assert mismatched_keys == expected_mismatched_keys
Loading