Skip to content

Commit f2e3f89

Browse files
committed
Merge branch 'main' into pr/vedantdalimkar/1079
2 parents 19eeebe + a084685 commit f2e3f89

35 files changed

+835
-192
lines changed

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,27 @@ The main features of the library are:
2525
- 800+ **pretrained** convolution- and transform-based encoders, including [timm](https://github.com/huggingface/pytorch-image-models) support
2626
- Popular metrics and losses for training routines (Dice, Jaccard, Tversky, ...)
2727
- ONNX export and torch script/trace/compile friendly
28+
29+
### Community-Driven Project, Supported By
30+
<table>
31+
<tr>
32+
<td align="center" vertical-align="center">
33+
<a href="https://withoutbg.com/?utm_source=smp&utm_medium=github_readme&utm_campaign=sponsorship" >
34+
<img src="https://withoutbg.com/images/logo-social.png" width="70px;" alt="withoutBG API Logo" />
35+
</a>
36+
</td>
37+
<td align="center" vertical-align="center">
38+
<b>withoutBG API</b>
39+
<br />
40+
<a href="https://withoutbg.com/?utm_source=smp&utm_medium=github_readme&utm_campaign=sponsorship">https://withoutbg.com</a>
41+
<br />
42+
<p width="200px">
43+
High-quality background removal API
44+
<br/>
45+
</p>
46+
</td>
47+
</tr>
48+
</table>
2849

2950
### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
3051

docs/save_load.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ For example:
4040
# Alternatively, load the model directly from the Hugging Face Hub
4141
model = smp.from_pretrained('username/my-model')
4242
43+
Loading pre-trained model with different number of classes for fine-tuning:
44+
45+
.. code:: python
46+
47+
import segmentation_models_pytorch as smp
48+
49+
model = smp.from_pretrained('<path-or-repo-name>', classes=5, strict=False)
50+
4351
Saving model Metrics and Dataset Name
4452
-------------------------------------
4553

examples/segformer_inference_pretrained.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
"metadata": {},
1414
"outputs": [],
1515
"source": [
16-
"# fix for HF hub download\n",
17-
"# see PR https://github.com/albumentations-team/albumentations/pull/2171\n",
18-
"!pip install -U git+https://github.com/qubvel/albumentations@patch-2"
16+
"# make sure you have the latest version of the libraries\n",
17+
"!pip install -U segmentation-models-pytorch\n",
18+
"!pip install albumentations matplotlib requests pillow"
1919
]
2020
},
2121
{

requirements/docs.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
autodocsumm==0.2.14
2-
huggingface-hub==0.29.1
2+
huggingface-hub==0.30.1
33
six==1.17.0
4-
sphinx==8.2.1
4+
sphinx==8.2.3
55
sphinx-book-theme==1.1.4

requirements/required.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
huggingface_hub==0.29.1
2-
numpy==2.2.3
3-
pillow==11.1.0
4-
safetensors==0.5.2
1+
huggingface_hub==0.30.1
2+
numpy==2.2.4
3+
pillow==11.2.0
4+
safetensors==0.5.3
55
timm==1.0.15
66
torch==2.6.0
77
torchvision==0.21.0

requirements/test.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
gitpython==3.1.44
22
packaging==24.2
3-
pytest==8.3.4
3+
pytest==8.3.5
44
pytest-xdist==3.6.1
5-
pytest-cov==6.0.0
6-
ruff==0.9.7
7-
setuptools==75.8.0
5+
pytest-cov==6.1.0
6+
ruff==0.11.3
7+
setuptools==78.1.0

segmentation_models_pytorch/base/model.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
2-
from typing import TypeVar, Type
2+
import warnings
33

4+
from typing import TypeVar, Type
45
from . import initialization as init
56
from .hub_mixin import SMPHubMixin
67
from .utils import is_torch_compiling
@@ -96,23 +97,45 @@ def load_state_dict(self, state_dict, **kwargs):
9697
# timm- ported encoders with TimmUniversalEncoder
9798
from segmentation_models_pytorch.encoders import TimmUniversalEncoder
9899

99-
if not isinstance(self.encoder, TimmUniversalEncoder):
100-
return super().load_state_dict(state_dict, **kwargs)
101-
102-
patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]
103-
104-
is_deprecated_encoder = any(
105-
self.encoder.name.startswith(pattern) for pattern in patterns
106-
)
107-
108-
if is_deprecated_encoder:
109-
keys = list(state_dict.keys())
110-
for key in keys:
111-
new_key = key
112-
if key.startswith("encoder.") and not key.startswith("encoder.model."):
113-
new_key = "encoder.model." + key.removeprefix("encoder.")
114-
if "gernet" in self.encoder.name:
115-
new_key = new_key.replace(".stages.", ".stages_")
116-
state_dict[new_key] = state_dict.pop(key)
100+
if isinstance(self.encoder, TimmUniversalEncoder):
101+
patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]
102+
is_deprecated_encoder = any(
103+
self.encoder.name.startswith(pattern) for pattern in patterns
104+
)
105+
if is_deprecated_encoder:
106+
keys = list(state_dict.keys())
107+
for key in keys:
108+
new_key = key
109+
if key.startswith("encoder.") and not key.startswith(
110+
"encoder.model."
111+
):
112+
new_key = "encoder.model." + key.removeprefix("encoder.")
113+
if "gernet" in self.encoder.name:
114+
new_key = new_key.replace(".stages.", ".stages_")
115+
state_dict[new_key] = state_dict.pop(key)
116+
117+
# To be able to load weight with mismatched sizes
118+
# We are going to filter mismatched sizes as well if strict=False
119+
strict = kwargs.get("strict", True)
120+
if not strict:
121+
mismatched_keys = []
122+
model_state_dict = self.state_dict()
123+
common_keys = set(model_state_dict.keys()) & set(state_dict.keys())
124+
for key in common_keys:
125+
if model_state_dict[key].shape != state_dict[key].shape:
126+
mismatched_keys.append(
127+
(key, model_state_dict[key].shape, state_dict[key].shape)
128+
)
129+
state_dict.pop(key)
130+
131+
if mismatched_keys:
132+
str_keys = "\n".join(
133+
[
134+
f" - {key}: {s} (weights) -> {m} (model)"
135+
for key, m, s in mismatched_keys
136+
]
137+
)
138+
text = f"\n\n !!!!!! Mismatched keys !!!!!!\n\nYou should TRAIN the model to use it:\n{str_keys}\n"
139+
warnings.warn(text, stacklevel=-1)
117140

118141
return super().load_state_dict(state_dict, **kwargs)

segmentation_models_pytorch/base/modules.py

Lines changed: 91 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, Dict, Union
2+
13
import torch
24
import torch.nn as nn
35

@@ -7,43 +9,109 @@
79
InPlaceABN = None
810

911

12+
def get_norm_layer(
13+
use_norm: Union[bool, str, Dict[str, Any]], out_channels: int
14+
) -> nn.Module:
15+
supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm")
16+
17+
# Step 1. Convert tot dict representation
18+
19+
## Check boolean
20+
if use_norm is True:
21+
norm_params = {"type": "batchnorm"}
22+
elif use_norm is False:
23+
norm_params = {"type": "identity"}
24+
25+
## Check string
26+
elif isinstance(use_norm, str):
27+
norm_str = use_norm.lower()
28+
if norm_str == "inplace":
29+
norm_params = {
30+
"type": "inplace",
31+
"activation": "leaky_relu",
32+
"activation_param": 0.0,
33+
}
34+
elif norm_str in supported_norms:
35+
norm_params = {"type": norm_str}
36+
else:
37+
raise ValueError(
38+
f"Unrecognized normalization type string provided: {use_norm}. Should be in "
39+
f"{supported_norms}"
40+
)
41+
42+
## Check dict
43+
elif isinstance(use_norm, dict):
44+
norm_params = use_norm
45+
46+
else:
47+
raise ValueError(
48+
f"Invalid type for use_norm should either be a bool (batchnorm/identity), "
49+
f"a string in {supported_norms}, or a dict like {{'type': 'batchnorm', **kwargs}}"
50+
)
51+
52+
# Step 2. Check if the dict is valid
53+
if "type" not in norm_params:
54+
raise ValueError(
55+
f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'."
56+
)
57+
if norm_params["type"] not in supported_norms:
58+
raise ValueError(
59+
f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}"
60+
)
61+
if norm_params["type"] == "inplace" and InPlaceABN is None:
62+
raise RuntimeError(
63+
"In order to use `use_norm='inplace'` the inplace_abn package must be installed. Use:\n"
64+
" $ pip install -U wheel setuptools\n"
65+
" $ pip install inplace_abn --no-build-isolation\n"
66+
"Also see: https://github.com/mapillary/inplace_abn"
67+
)
68+
69+
# Step 3. Initialize the norm layer
70+
norm_type = norm_params["type"]
71+
norm_kwargs = {k: v for k, v in norm_params.items() if k != "type"}
72+
73+
if norm_type == "inplace":
74+
norm = InPlaceABN(out_channels, **norm_kwargs)
75+
elif norm_type == "batchnorm":
76+
norm = nn.BatchNorm2d(out_channels, **norm_kwargs)
77+
elif norm_type == "identity":
78+
norm = nn.Identity()
79+
elif norm_type == "layernorm":
80+
norm = nn.LayerNorm(out_channels, **norm_kwargs)
81+
elif norm_type == "instancenorm":
82+
norm = nn.InstanceNorm2d(out_channels, **norm_kwargs)
83+
else:
84+
raise ValueError(f"Unrecognized normalization type: {norm_type}")
85+
86+
return norm
87+
88+
1089
class Conv2dReLU(nn.Sequential):
1190
def __init__(
1291
self,
13-
in_channels,
14-
out_channels,
15-
kernel_size,
16-
padding=0,
17-
stride=1,
18-
use_batchnorm=True,
92+
in_channels: int,
93+
out_channels: int,
94+
kernel_size: int,
95+
padding: int = 0,
96+
stride: int = 1,
97+
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
1998
):
20-
if use_batchnorm == "inplace" and InPlaceABN is None:
21-
raise RuntimeError(
22-
"In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
23-
+ "To install see: https://github.com/mapillary/inplace_abn"
24-
)
99+
norm = get_norm_layer(use_norm, out_channels)
25100

101+
is_identity = isinstance(norm, nn.Identity)
26102
conv = nn.Conv2d(
27103
in_channels,
28104
out_channels,
29105
kernel_size,
30106
stride=stride,
31107
padding=padding,
32-
bias=not (use_batchnorm),
108+
bias=is_identity,
33109
)
34-
relu = nn.ReLU(inplace=True)
35-
36-
if use_batchnorm == "inplace":
37-
bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
38-
relu = nn.Identity()
39110

40-
elif use_batchnorm and use_batchnorm != "inplace":
41-
bn = nn.BatchNorm2d(out_channels)
42-
43-
else:
44-
bn = nn.Identity()
111+
is_inplaceabn = InPlaceABN is not None and isinstance(norm, InPlaceABN)
112+
activation = nn.Identity() if is_inplaceabn else nn.ReLU(inplace=True)
45113

46-
super(Conv2dReLU, self).__init__(conv, bn, relu)
114+
super(Conv2dReLU, self).__init__(conv, norm, activation)
47115

48116

49117
class SCSEModule(nn.Module):

segmentation_models_pytorch/decoders/fpn/decoder.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
2525

2626

2727
class FPNBlock(nn.Module):
28-
def __init__(self, pyramid_channels: int, skip_channels: int):
28+
def __init__(
29+
self,
30+
pyramid_channels: int,
31+
skip_channels: int,
32+
interpolation_mode: str = "nearest",
33+
):
2934
super().__init__()
3035
self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1)
36+
self.interpolation_mode = interpolation_mode
3137

3238
def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
33-
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
39+
x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode)
3440
skip = self.skip_conv(skip)
3541
x = x + skip
3642
return x
@@ -84,6 +90,7 @@ def __init__(
8490
segmentation_channels: int = 128,
8591
dropout: float = 0.2,
8692
merge_policy: Literal["add", "cat"] = "add",
93+
interpolation_mode: str = "nearest",
8794
):
8895
super().__init__()
8996

@@ -103,9 +110,9 @@ def __init__(
103110
encoder_channels = encoder_channels[: encoder_depth + 1]
104111

105112
self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
106-
self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
107-
self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
108-
self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])
113+
self.p4 = FPNBlock(pyramid_channels, encoder_channels[1], interpolation_mode)
114+
self.p3 = FPNBlock(pyramid_channels, encoder_channels[2], interpolation_mode)
115+
self.p2 = FPNBlock(pyramid_channels, encoder_channels[3], interpolation_mode)
109116

110117
self.seg_blocks = nn.ModuleList(
111118
[

segmentation_models_pytorch/decoders/fpn/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class FPN(SegmentationModel):
2828
decoder_merge_policy: Determines how to merge pyramid features inside FPN. Available options are **add**
2929
and **cat**
3030
decoder_dropout: Spatial dropout rate in range (0, 1) for feature pyramid in FPN_
31+
decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
32+
**"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
3133
in_channels: A number of input channels for the model, default is 3 (RGB images)
3234
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
3335
activation: An activation function to apply after the final convolution layer.
@@ -61,6 +63,7 @@ def __init__(
6163
decoder_segmentation_channels: int = 128,
6264
decoder_merge_policy: str = "add",
6365
decoder_dropout: float = 0.2,
66+
decoder_interpolation: str = "nearest",
6467
in_channels: int = 3,
6568
classes: int = 1,
6669
activation: Optional[str] = None,
@@ -91,6 +94,7 @@ def __init__(
9194
segmentation_channels=decoder_segmentation_channels,
9295
dropout=decoder_dropout,
9396
merge_policy=decoder_merge_policy,
97+
interpolation_mode=decoder_interpolation,
9498
)
9599

96100
self.segmentation_head = SegmentationHead(

0 commit comments

Comments
 (0)