
Commit a3d4487

Add semantic segmentation popular losses (#317)
1 parent 04f59db commit a3d4487

File tree

15 files changed: +1194 −1 lines changed


docker/Dockerfile.dev

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-FROM anibali/pytorch:no-cuda
+FROM anibali/pytorch:1.5.0-nocuda

 WORKDIR /tmp/smp/

docs/conf.py

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ def get_version():
 autodoc_mock_imports = [
     'torch',
     'tqdm',
+    'numpy',
     'timm',
     'pretrainedmodels',
     'torchvision',

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ Welcome to Segmentation Models's documentation!
    quickstart
    models
    encoders
+   losses
    insights

docs/losses.rst

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
📉 Losses
=========

Collection of popular semantic segmentation losses, adapted from the
pytorch-toolbelt repository: https://github.com/BloodAxe/pytorch-toolbelt

Constants
~~~~~~~~~
.. automodule:: segmentation_models_pytorch.losses.constants
   :members:

JaccardLoss
~~~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.losses.JaccardLoss

DiceLoss
~~~~~~~~
.. autoclass:: segmentation_models_pytorch.losses.DiceLoss

FocalLoss
~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.losses.FocalLoss

LovaszLoss
~~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.losses.LovaszLoss

SoftBCEWithLogitsLoss
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.losses.SoftBCEWithLogitsLoss

SoftCrossEntropyLoss
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.losses.SoftCrossEntropyLoss
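
A minimal usage sketch (assuming, as in pytorch-toolbelt, that each loss class
takes the mode string as its first argument and consumes raw logits):

    import segmentation_models_pytorch as smp

    dice = smp.losses.DiceLoss(mode=smp.losses.BINARY_MODE)
    focal = smp.losses.FocalLoss(mode=smp.losses.BINARY_MODE)

    # inside a training step, with logits of shape (N, 1, H, W)
    # and masks of shape (N, H, W):
    # loss = dice(logits, masks) + focal(logits, masks)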

segmentation_models_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@

 from . import encoders
 from . import utils
+from . import losses

 from .__version__ import __version__

segmentation_models_pytorch/losses/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE

from .jaccard import JaccardLoss
from .dice import DiceLoss
from .focal import FocalLoss
from .lovasz import LovaszLoss
from .soft_bce import SoftBCEWithLogitsLoss
from .soft_ce import SoftCrossEntropyLoss
segmentation_models_pytorch/losses/_functional.py

Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
import math
import numpy as np

from typing import Optional

import torch
import torch.nn.functional as F


__all__ = [
    "focal_loss_with_logits",
    "softmax_focal_loss_with_logits",
    "soft_jaccard_score",
    "soft_dice_score",
    "wing_loss",
]


def to_tensor(x, dtype=None) -> torch.Tensor:
    if isinstance(x, torch.Tensor):
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, (list, tuple)):
        # np.array converts the values themselves; np.ndarray would
        # wrongly interpret the argument as a shape
        x = np.array(x)
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x


def focal_loss_with_logits(
    output: torch.Tensor,
    target: torch.Tensor,
    gamma: float = 2.0,
    alpha: Optional[float] = 0.25,
    reduction: str = "mean",
    normalized: bool = False,
    reduced_threshold: Optional[float] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Compute binary focal loss between target and output logits.
    See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.

    Args:
        output: Tensor of arbitrary shape (predictions of the model)
        target: Tensor of the same shape as output
        gamma: Focal loss power factor
        alpha: Weight factor to balance positive and negative samples. Alpha must be
            in the [0, 1] range; higher values give more weight to the positive class.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed,
            'batchwise_mean': computes mean loss per sample in batch. Default: 'mean'
        normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
        reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).

    References:
        https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
    """
    target = target.type(output.type())

    logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
    pt = torch.exp(-logpt)

    # compute the focal modulation term (1 - pt) ** gamma
    if reduced_threshold is None:
        focal_term = (1.0 - pt).pow(gamma)
    else:
        focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
        focal_term[pt < reduced_threshold] = 1

    loss = focal_term * logpt

    if alpha is not None:
        loss *= alpha * target + (1 - alpha) * (1 - target)

    if normalized:
        norm_factor = focal_term.sum().clamp_min(eps)
        loss /= norm_factor

    if reduction == "mean":
        loss = loss.mean()
    if reduction == "sum":
        loss = loss.sum()
    if reduction == "batchwise_mean":
        loss = loss.sum(0)

    return loss


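# A quick usage sketch (shapes illustrative): the function expects raw logits
# and a float target of the same shape.
#
#   logits = torch.randn(4, 1, 8, 8)
#   targets = torch.randint(0, 2, (4, 1, 8, 8)).float()
#   loss = focal_loss_with_logits(logits, targets, gamma=2.0, alpha=0.25)

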
def softmax_focal_loss_with_logits(
    output: torch.Tensor,
    target: torch.Tensor,
    gamma: float = 2.0,
    reduction="mean",
    normalized=False,
    reduced_threshold: Optional[float] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Softmax version of focal loss between target and output logits.
    See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.

    Args:
        output: Tensor of shape [B, C, *] (similar to nn.CrossEntropyLoss)
        target: Tensor of shape [B, *] (similar to nn.CrossEntropyLoss)
        gamma: Focal loss power factor
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed,
            'batchwise_mean': computes mean loss per sample in batch. Default: 'mean'
        normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
        reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).
    """
    log_softmax = F.log_softmax(output, dim=1)

    loss = F.nll_loss(log_softmax, target, reduction="none")
    pt = torch.exp(-loss)

    # compute the focal modulation term (1 - pt) ** gamma
    if reduced_threshold is None:
        focal_term = (1.0 - pt).pow(gamma)
    else:
        focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
        focal_term[pt < reduced_threshold] = 1

    loss = focal_term * loss

    if normalized:
        norm_factor = focal_term.sum().clamp_min(eps)
        loss = loss / norm_factor

    if reduction == "mean":
        loss = loss.mean()
    if reduction == "sum":
        loss = loss.sum()
    if reduction == "batchwise_mean":
        loss = loss.sum(0)

    return loss


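# A quick usage sketch (shapes illustrative): targets are integer class indices,
# as for nn.CrossEntropyLoss.
#
#   logits = torch.randn(4, 3, 8, 8)              # 3-class logits
#   targets = torch.randint(0, 3, (4, 8, 8))      # integer class map
#   loss = softmax_focal_loss_with_logits(logits, targets, gamma=2.0)

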
def soft_jaccard_score(
    output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
) -> torch.Tensor:
    assert output.size() == target.size()
    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)
        cardinality = torch.sum(output + target, dim=dims)
    else:
        intersection = torch.sum(output * target)
        cardinality = torch.sum(output + target)

    # soft IoU: |A ∩ B| / |A ∪ B|, with |A ∪ B| = |A| + |B| - |A ∩ B|
    union = cardinality - intersection
    jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps)
    return jaccard_score


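# Worked example: for hard 0/1 masks the soft score reduces to plain IoU.
#
#   a = torch.tensor([1.0, 1.0, 0.0])
#   b = torch.tensor([1.0, 0.0, 0.0])
#   soft_jaccard_score(a, b)   # intersection = 1, union = 2 -> 0.5

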
def soft_dice_score(
    output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
) -> torch.Tensor:
    assert output.size() == target.size()
    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)
        cardinality = torch.sum(output + target, dim=dims)
    else:
        intersection = torch.sum(output * target)
        cardinality = torch.sum(output + target)
    # soft Dice: 2 * |A ∩ B| / (|A| + |B|)
    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
    return dice_score


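# Worked example: on the same masks as above, Dice relates to Jaccard
# via D = 2J / (1 + J).
#
#   a = torch.tensor([1.0, 1.0, 0.0])
#   b = torch.tensor([1.0, 0.0, 0.0])
#   soft_dice_score(a, b)   # 2*1 / 3 ≈ 0.6667 (with J = 0.5 above)

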
def wing_loss(output: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"):
    """Wing loss (https://arxiv.org/pdf/1711.06753.pdf).

    :param output: Predicted values
    :param target: Target values
    :param width: Threshold below which the loss is logarithmic (w in the paper)
    :param curvature: Curvature of the logarithmic region (epsilon in the paper)
    :param reduction: 'mean' | 'sum' | 'none'
    :return: Loss tensor (scalar if reduced)
    """
    diff_abs = (target - output).abs()
    loss = diff_abs.clone()

    idx_smaller = diff_abs < width
    idx_bigger = diff_abs >= width

    # logarithmic region: w * ln(1 + |x| / curvature)
    loss[idx_smaller] = width * torch.log(1 + diff_abs[idx_smaller] / curvature)

    # linear region: |x| - C, where C makes the two pieces continuous at |x| = width
    C = width - width * math.log(1 + width / curvature)
    loss[idx_bigger] = loss[idx_bigger] - C

    if reduction == "sum":
        loss = loss.sum()

    if reduction == "mean":
        loss = loss.mean()

    return loss


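# A quick usage sketch: differences below `width` fall in the log region,
# larger ones in the linear region.
#
#   pred = torch.tensor([0.0, 3.0, 10.0])
#   gt = torch.zeros(3)
#   wing_loss(pred, gt, reduction="none")   # |diff| = 0, 3 -> log region; 10 -> linear

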
def label_smoothed_nll_loss(
    lprobs: torch.Tensor, target: torch.Tensor, epsilon: float, ignore_index=None, reduction="mean", dim=-1
) -> torch.Tensor:
    """NLL loss with label smoothing.

    Source: https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py

    :param lprobs: Log-probabilities of predictions (e.g. after log_softmax)
    :param target: Target class indices
    :param epsilon: Smoothing factor in [0, 1]; 0 recovers plain NLL loss
    :param ignore_index: Target value whose positions are excluded from the loss
    :param reduction: 'mean' | 'sum' | 'none'
    :return: Loss tensor (scalar if reduced)
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(dim)

    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        target = target.masked_fill(pad_mask, 0)
        nll_loss = -lprobs.gather(dim=dim, index=target)
        smooth_loss = -lprobs.sum(dim=dim, keepdim=True)

        # out-of-place masked_fill keeps the autograd graph intact
        nll_loss = nll_loss.masked_fill(pad_mask, 0.0)
        smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0)
    else:
        nll_loss = -lprobs.gather(dim=dim, index=target)
        smooth_loss = -lprobs.sum(dim=dim, keepdim=True)

    nll_loss = nll_loss.squeeze(dim)
    smooth_loss = smooth_loss.squeeze(dim)

    if reduction == "sum":
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    if reduction == "mean":
        nll_loss = nll_loss.mean()
        smooth_loss = smooth_loss.mean()

    # interpolate between the true-class NLL and the uniform-over-classes loss
    eps_i = epsilon / lprobs.size(dim)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss
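
# A quick usage sketch: the function consumes log-probabilities, not logits.
#
#   lprobs = F.log_softmax(torch.randn(4, 3), dim=-1)
#   target = torch.tensor([0, 2, 1, 2])
#   label_smoothed_nll_loss(lprobs, target, epsilon=0.1)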
segmentation_models_pytorch/losses/constants.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
#: Loss binary mode assumes you are solving a binary segmentation task.
#: That means you have only one class, whose pixels are labeled as **1**;
#: the remaining pixels are background and labeled as **0**.
#: Target mask shape - (N, H, W), model output mask shape (N, 1, H, W).
BINARY_MODE: str = "binary"

#: Loss multiclass mode assumes you are solving a multi-**class** segmentation task.
#: That means you have *C = 1..N* classes which have unique label values,
#: classes are mutually exclusive and all pixels are labeled with these values.
#: Target mask shape - (N, H, W), model output mask shape (N, C, H, W).
MULTICLASS_MODE: str = "multiclass"

#: Loss multilabel mode assumes you are solving a multi-**label** segmentation task.
#: That means you have *C = 1..N* classes whose pixels are labeled as **1**;
#: classes are not mutually exclusive and each class has its own *channel*,
#: so pixels in each channel which do not belong to the class are labeled as **0**.
#: Target mask shape - (N, C, H, W), model output mask shape (N, C, H, W).
MULTILABEL_MODE: str = "multilabel"
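
A short sketch of how the mode constants pair with target/output shapes (assuming
the loss classes accept the mode string as their first argument):

    import torch
    from segmentation_models_pytorch.losses import DiceLoss, MULTICLASS_MODE

    loss_fn = DiceLoss(mode=MULTICLASS_MODE)
    logits = torch.randn(2, 5, 64, 64)            # (N, C, H, W) output for C=5 classes
    target = torch.randint(0, 5, (2, 64, 64))     # (N, H, W) integer class map
    loss = loss_fn(logits, target)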
