You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: segmentation_models_pytorch/decoders/deeplabv3/model.py
+30-6Lines changed: 30 additions & 6 deletions
Original file line number
Diff line number
Diff line change
@@ -35,15 +35,16 @@ class DeepLabV3(SegmentationModel):
35
35
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
36
36
**callable** and **None**.
37
37
Default is **None**
38
-
upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity).
38
+
upsampling: Final upsampling factor. Default is **None** to preserve input-output spatial shape identity
39
39
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
40
40
on top of encoder if **aux_params** is not **None** (default). Supported params:
41
41
- classes (int): A number of classes
42
42
- pooling (str): One of "max", "avg". Default is "avg"
43
43
- dropout (float): Dropout factor in [0, 1)
44
44
- activation (str): An activation function to apply "sigmoid"/"softmax"
45
45
(could be **None** to return logits)
46
-
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
46
+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
47
+
Keys with ``None`` values are pruned before passing.
47
48
48
49
Returns:
49
50
``torch.nn.Module``: **DeepLabV3**
@@ -72,6 +73,13 @@ def __init__(
72
73
):
73
74
super().__init__()
74
75
76
+
ifencoder_output_stridenotin [8, 16]:
77
+
raiseValueError(
78
+
"DeeplabV3 support output stride 8 or 16, got {}.".format(
@@ -138,7 +154,8 @@ class DeepLabV3Plus(SegmentationModel):
138
154
- dropout (float): Dropout factor in [0, 1)
139
155
- activation (str): An activation function to apply "sigmoid"/"softmax"
140
156
(could be **None** to return logits)
141
-
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
157
+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
158
+
Keys with ``None`` values are pruned before passing.
142
159
143
160
Returns:
144
161
``torch.nn.Module``: **DeepLabV3Plus**
@@ -167,6 +184,13 @@ def __init__(
167
184
):
168
185
super().__init__()
169
186
187
+
ifencoder_output_stridenotin [8, 16]:
188
+
raiseValueError(
189
+
"DeeplabV3Plus support output stride 8 or 16, got {}.".format(
0 commit comments