Skip to content

Commit 56efd39

Browse files
authored
Update lovasz.py (#378)
1 parent 33dc950 commit 56efd39

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

segmentation_models_pytorch/losses/lovasz.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import torch
1010
import torch.nn.functional as F
11-
from torch.autograd import Variable
1211
from torch.nn.modules.loss import _Loss
1312
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
1413

@@ -37,7 +36,7 @@ def _lovasz_grad(gt_sorted):
3736
def _lovasz_hinge(logits, labels, per_image=True, ignore=None):
3837
"""
3938
Binary Lovasz hinge loss
40-
logits: [B, H, W] Variable, logits at each pixel (between -infinity and +infinity)
39+
logits: [B, H, W] Logits at each pixel (between -infinity and +infinity)
4140
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
4241
per_image: compute the loss per image instead of per batch
4342
ignore: void class id
@@ -55,20 +54,20 @@ def _lovasz_hinge(logits, labels, per_image=True, ignore=None):
5554
def _lovasz_hinge_flat(logits, labels):
5655
"""Binary Lovasz hinge loss
5756
Args:
58-
logits: [P] Variable, logits at each prediction (between -infinity and +infinity)
57+
logits: [P] Logits at each prediction (between -infinity and +infinity)
5958
labels: [P] Tensor, binary ground truth labels (0 or 1)
6059
ignore: label to ignore
6160
"""
6261
if len(labels) == 0:
6362
# only void pixels, the gradients should be 0
6463
return logits.sum() * 0.0
6564
signs = 2.0 * labels.float() - 1.0
66-
errors = 1.0 - logits * Variable(signs)
65+
errors = 1.0 - logits * signs
6766
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
6867
perm = perm.data
6968
gt_sorted = labels[perm]
7069
grad = _lovasz_grad(gt_sorted)
71-
loss = torch.dot(F.relu(errors_sorted), Variable(grad))
70+
loss = torch.dot(F.relu(errors_sorted), grad)
7271
return loss
7372

7473

@@ -92,7 +91,7 @@ def _flatten_binary_scores(scores, labels, ignore=None):
9291
def _lovasz_softmax(probas, labels, classes="present", per_image=False, ignore=None):
9392
"""Multi-class Lovasz-Softmax loss
9493
Args:
95-
@param probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
94+
@param probas: [B, C, H, W] Class probabilities at each prediction (between 0 and 1).
9695
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
9796
@param labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
9897
@param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
@@ -112,7 +111,7 @@ def _lovasz_softmax(probas, labels, classes="present", per_image=False, ignore=N
112111
def _lovasz_softmax_flat(probas, labels, classes="present"):
113112
"""Multi-class Lovasz-Softmax loss
114113
Args:
115-
@param probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
114+
@param probas: [P, C] Class probabilities at each prediction (between 0 and 1)
116115
@param labels: [P] Tensor, ground truth labels (between 0 and C - 1)
117116
@param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
118117
"""
@@ -225,4 +224,4 @@ def forward(self, y_pred, y_true):
225224
loss = _lovasz_softmax(y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index)
226225
else:
227226
raise ValueError("Wrong mode {}.".format(self.mode))
228-
return loss
227+
return loss

0 commit comments

Comments
 (0)