Skip to content

Commit d121fec

Browse files
committed
Add depth validation
1 parent 5bbb1db commit d121fec

File tree

16 files changed

+84
-1
lines changed

16 files changed

+84
-1
lines changed

segmentation_models_pytorch/encoders/densenet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@
3232

3333
class DenseNetEncoder(DenseNet, EncoderMixin):
3434
def __init__(self, out_channels, depth=5, output_stride=32, **kwargs):
35+
if depth > 5 or depth < 1:
36+
raise ValueError(
37+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
38+
)
39+
3540
super().__init__(**kwargs)
41+
3642
self._depth = depth
3743
self._in_channels = 3
3844
self._out_channels = out_channels

segmentation_models_pytorch/encoders/dpn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ def __init__(
4444
output_stride: int = 32,
4545
**kwargs,
4646
):
47+
if depth > 5 or depth < 1:
48+
raise ValueError(
49+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
50+
)
51+
4752
super().__init__(**kwargs)
4853
self._stage_idxs = stage_idxs
4954
self._depth = depth

segmentation_models_pytorch/encoders/efficientnet.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ def __init__(
4343
depth: int = 5,
4444
output_stride: int = 32,
4545
):
46+
if depth > 5 or depth < 1:
47+
raise ValueError(
48+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
49+
)
50+
4651
blocks_args, global_params = get_model_params(model_name, override_params=None)
4752
super().__init__(blocks_args, global_params)
4853

segmentation_models_pytorch/encoders/inceptionresnetv2.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ def __init__(
3939
output_stride: int = 32,
4040
**kwargs,
4141
):
42+
if depth > 5 or depth < 1:
43+
raise ValueError(
44+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
45+
)
46+
4247
super().__init__(**kwargs)
4348

4449
self._depth = depth

segmentation_models_pytorch/encoders/inceptionv4.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def __init__(
4040
output_stride: int = 32,
4141
**kwargs,
4242
):
43+
if depth > 5 or depth < 1:
44+
raise ValueError(
45+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
46+
)
4347
super().__init__(**kwargs)
4448

4549
self._depth = depth

segmentation_models_pytorch/encoders/mix_transformer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,10 @@ class MixVisionTransformerEncoder(MixVisionTransformer, EncoderMixin):
529529
def __init__(
530530
self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs
531531
):
532+
if depth > 5 or depth < 1:
533+
raise ValueError(
534+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
535+
)
532536
super().__init__(**kwargs)
533537

534538
self._depth = depth

segmentation_models_pytorch/encoders/mobilenet.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin):
3434
def __init__(
3535
self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs
3636
):
37+
if depth > 5 or depth < 1:
38+
raise ValueError(
39+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
40+
)
3741
super().__init__(**kwargs)
3842

3943
self._depth = depth

segmentation_models_pytorch/encoders/mobileone.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,11 @@ def __init__(
319319
:param use_se: Whether to use SE-ReLU activations.
320320
:param num_conv_branches: Number of linear conv branches.
321321
"""
322+
if depth > 5 or depth < 1:
323+
raise ValueError(
324+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
325+
)
326+
322327
super().__init__()
323328

324329
assert len(width_multipliers) == 4

segmentation_models_pytorch/encoders/resnet.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ class ResNetEncoder(ResNet, EncoderMixin):
3838
def __init__(
3939
self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs
4040
):
41+
if depth > 5 or depth < 1:
42+
raise ValueError(
43+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
44+
)
4145
super().__init__(**kwargs)
4246

4347
self._depth = depth

segmentation_models_pytorch/encoders/senet.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def __init__(
4343
output_stride: int = 32,
4444
**kwargs,
4545
):
46+
if depth > 5 or depth < 1:
47+
raise ValueError(
48+
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
49+
)
4650
super().__init__(**kwargs)
4751

4852
self._depth = depth

0 commit comments

Comments
 (0)