Skip to content

Commit 66bc2c8

Browse files
committed
Fix loading from config
1 parent eadfe1f commit 66bc2c8

File tree

11 files changed

+36
-3
lines changed

11 files changed

+36
-3
lines changed

segmentation_models_pytorch/base/hub_mixin.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,15 @@ def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
136136

137137
model_class = getattr(smp, model_class_name)
138138
return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
139+
140+
141+
def supports_config_loading(func):
142+
"""Decorator to filter special config kwargs"""
143+
144+
@wraps(func)
145+
def wrapper(self, *args, **kwargs):
146+
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
147+
return func(self, *args, **kwargs)
148+
149+
return wrapper
150+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading

segmentation_models_pytorch/decoders/deeplabv3/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
SegmentationModel,
99
)
1010
from segmentation_models_pytorch.encoders import get_encoder
11+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
1112

1213
from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
1314

@@ -54,6 +55,7 @@ class DeepLabV3(SegmentationModel):
5455
5556
"""
5657

58+
@supports_config_loading
5759
def __init__(
5860
self,
5961
encoder_name: str = "resnet34",
@@ -163,6 +165,7 @@ class DeepLabV3Plus(SegmentationModel):
163165
164166
"""
165167

168+
@supports_config_loading
166169
def __init__(
167170
self,
168171
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/fpn/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import FPNDecoder
1112

@@ -51,6 +52,7 @@ class FPN(SegmentationModel):
5152
5253
"""
5354

55+
@supports_config_loading
5456
def __init__(
5557
self,
5658
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/linknet/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import LinknetDecoder
1112

@@ -53,6 +54,7 @@ class Linknet(SegmentationModel):
5354
https://arxiv.org/abs/1707.03718
5455
"""
5556

57+
@supports_config_loading
5658
def __init__(
5759
self,
5860
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/manet/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import MAnetDecoder
1112

@@ -56,6 +57,7 @@ class MAnet(SegmentationModel):
5657
5758
"""
5859

60+
@supports_config_loading
5961
def __init__(
6062
self,
6163
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/pan/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import PANDecoder
1112

@@ -53,6 +54,7 @@ class PAN(SegmentationModel):
5354
5455
"""
5556

57+
@supports_config_loading
5658
def __init__(
5759
self,
5860
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/pspnet/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import PSPDecoder
1112

@@ -54,6 +55,7 @@ class PSPNet(SegmentationModel):
5455
https://arxiv.org/abs/1612.01105
5556
"""
5657

58+
@supports_config_loading
5759
def __init__(
5860
self,
5961
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/segformer/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import SegformerDecoder
1112

@@ -46,6 +47,7 @@ class Segformer(SegmentationModel):
4647
4748
"""
4849

50+
@supports_config_loading
4951
def __init__(
5052
self,
5153
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/unet/model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Any, List, Optional, Union
1+
from typing import Any, Optional, Union, Tuple, Callable
22

33
from segmentation_models_pytorch.base import (
44
ClassificationHead,
55
SegmentationHead,
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import UnetDecoder
1112

@@ -55,17 +56,18 @@ class Unet(SegmentationModel):
5556
5657
"""
5758

59+
@supports_config_loading
5860
def __init__(
5961
self,
6062
encoder_name: str = "resnet34",
6163
encoder_depth: int = 5,
6264
encoder_weights: Optional[str] = "imagenet",
6365
decoder_use_batchnorm: bool = True,
64-
decoder_channels: List[int] = (256, 128, 64, 32, 16),
66+
decoder_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
6567
decoder_attention_type: Optional[str] = None,
6668
in_channels: int = 3,
6769
classes: int = 1,
68-
activation: Optional[Union[str, callable]] = None,
70+
activation: Optional[Union[str, Callable]] = None,
6971
aux_params: Optional[dict] = None,
7072
**kwargs: dict[str, Any],
7173
):

segmentation_models_pytorch/decoders/unetplusplus/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import UnetPlusPlusDecoder
1112

@@ -55,6 +56,7 @@ class UnetPlusPlus(SegmentationModel):
5556
5657
"""
5758

59+
@supports_config_loading
5860
def __init__(
5961
self,
6062
encoder_name: str = "resnet34",

0 commit comments

Comments
 (0)