
Commit 9e33cfa

Docstrings updated

1 parent f1f42ff · commit 9e33cfa

5 files changed (+23, -8 lines)

segmentation_models_pytorch/base/encoder_decoder.py
Lines changed: 14 additions & 3 deletions

@@ -9,22 +9,33 @@ def __init__(self, encoder, decoder, activation):
         super().__init__()
         self.encoder = encoder
         self.decoder = decoder
-
+
         if callable(activation) or activation is None:
             self.activation = activation
         elif activation == 'softmax':
             self.activation = nn.Softmax(dim=1)
         elif activation == 'sigmoid':
             self.activation = nn.Sigmoid()
         else:
-            raise ValueError('Activation should be "sigmoid" or "softmax"')
-
+            raise ValueError('Activation should be "sigmoid"/"softmax"/callable/None')
+
     def forward(self, x):
+        """Sequentially pass `x` through model's `encoder` and `decoder` (return logits!)"""
         x = self.encoder(x)
         x = self.decoder(x)
         return x
 
     def predict(self, x):
+        """Inference method. Switch model to `eval` mode, call `.forward(x)`
+        and apply activation function (if activation is not `None`) with `torch.no_grad()`
+
+        Args:
+            x: 4D torch tensor with shape (batch_size, channels, height, width)
+
+        Return:
+            prediction: 4D torch tensor with shape (batch_size, classes, height, width)
+
+        """
         if self.training:
             self.eval()
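Read together, the new docstrings pin down the model contract: ``forward`` returns raw logits, while ``predict`` switches the model to eval mode, runs the forward pass under ``torch.no_grad()``, and applies the configured activation (if any). A minimal self-contained sketch of that contract follows; it is not the library's exact source, and the class and variable names are illustrative.

import torch
import torch.nn as nn


class EncoderDecoderSketch(nn.Module):
    """Toy stand-in illustrating the forward/predict contract documented above."""

    def __init__(self, encoder, decoder, activation=None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        if callable(activation) or activation is None:
            self.activation = activation
        elif activation == 'softmax':
            self.activation = nn.Softmax(dim=1)
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        else:
            raise ValueError('Activation should be "sigmoid"/"softmax"/callable/None')

    def forward(self, x):
        # Returns raw logits, as the forward() docstring notes.
        return self.decoder(self.encoder(x))

    def predict(self, x):
        # Inference: eval mode, no gradients, optional activation on top of the logits.
        if self.training:
            self.eval()
        with torch.no_grad():
            out = self.forward(x)
            if self.activation is not None:
                out = self.activation(out)
        return out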

segmentation_models_pytorch/fpn/model.py
Lines changed: 3 additions & 2 deletions

@@ -13,10 +13,11 @@ class FPN(EncoderDecoder):
         decoder_segmentation_channels: a number of convolution filters in segmentation head of FPN_.
         classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
         dropout: spatial dropout rate in range (0, 1).
-        activation: one of [``sigmoid``, ``softmax``, None]
+        activation: activation function used in ``.predict(x)`` method for inference.
+            One of [``sigmoid``, ``softmax``, callable, None]
 
     Returns:
-        ``keras.models.Model``: **FPN**
+        ``torch.nn.Module``: **FPN**
 
     .. _FPN:
         http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf
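To make the documented ``activation`` behavior concrete, here is a hypothetical usage sketch. It assumes the package is importable as ``segmentation_models_pytorch`` and that ``encoder_weights=None`` is accepted to skip downloading pretrained encoder weights; only ``classes`` and ``activation`` come from the docstring above.

import torch
import segmentation_models_pytorch as smp

# `classes` and `activation` are documented above; `encoder_weights=None`
# is an assumption used here only to avoid downloading pretrained weights.
model = smp.FPN(encoder_weights=None, classes=3, activation='softmax')
model.eval()

x = torch.rand(1, 3, 256, 256)        # (batch_size, channels, height, width)

logits = model(x)                     # forward(): raw logits
probs = model.predict(x)              # predict(): softmax applied under torch.no_grad()

print(logits.shape, probs.shape)      # expected (1, 3, 256, 256) per the docstring
print(float(probs[0, :, 0, 0].sum())) # ~1.0: softmax normalises over the class dimension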

segmentation_models_pytorch/linknet/model.py
Lines changed: 2 additions & 1 deletion

@@ -16,7 +16,8 @@ class Linknet(EncoderDecoder):
         decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
             is used.
         classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
-        activation: one of [``sigmoid``, ``softmax``, None]
+        activation: activation function used in ``.predict(x)`` method for inference.
+            One of [``sigmoid``, ``softmax``, callable, None]
 
     Returns:
         ``torch.nn.Module``: **Linknet**

segmentation_models_pytorch/pspnet/model.py
Lines changed: 2 additions & 1 deletion

@@ -18,7 +18,8 @@ class PSPNet(EncoderDecoder):
         psp_aux_output: if ``True`` add auxiliary classification output for encoder training
         psp_dropout: spatial dropout rate between 0 and 1.
         classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
-        activation: one of [``sigmoid``, ``softmax``, None]
+        activation: activation function used in ``.predict(x)`` method for inference.
+            One of [``sigmoid``, ``softmax``, callable, None]
 
     Returns:
         ``torch.nn.Module``: **PSPNet**

segmentation_models_pytorch/unet/model.py
Lines changed: 2 additions & 1 deletion

@@ -14,7 +14,8 @@ class Unet(EncoderDecoder):
         decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
             is used.
         classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
-        activation: one of [``sigmoid``, ``softmax``, None]
+        activation: activation function used in ``.predict(x)`` method for inference.
+            One of [``sigmoid``, ``softmax``, callable, None]
         center: if ``True`` add ``Conv2dReLU`` block on encoder head (useful for VGG models)
 
     Returns:
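The docstrings now also allow a callable as ``activation``; a short hypothetical sketch of that option follows, under the same assumptions as above about the import name and ``encoder_weights=None``.

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

# A callable activation instead of the 'sigmoid' / 'softmax' strings.
model = smp.Unet(encoder_weights=None, classes=1, activation=nn.Sigmoid())

x = torch.rand(2, 3, 224, 224)
mask = model.predict(x)               # the callable is applied to the logits inside predict()

print(mask.shape)                     # expected (2, 1, 224, 224) per the docstring
print(mask.min().item(), mask.max().item())   # both within [0, 1] after sigmoid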
