Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions segmentation_models_pytorch/losses/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
log_loss: bool = False,
from_logits: bool = True,
smooth: float = 0.0,
ignore_index: Optional[int] = None,
eps: float = 1e-7,
):
"""Jaccard loss for image segmentation task.
Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(
self.classes = classes
self.from_logits = from_logits
self.smooth = smooth
self.ignore_index = ignore_index
self.eps = eps
self.log_loss = log_loss

Expand All @@ -74,17 +76,36 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
y_true = y_true.view(bs, 1, -1)
y_pred = y_pred.view(bs, 1, -1)

if self.ignore_index is not None:
mask = y_true != self.ignore_index
y_pred = y_pred * mask
y_true = y_true * mask

if self.mode == MULTICLASS_MODE:
y_true = y_true.view(bs, -1)
y_pred = y_pred.view(bs, num_classes, -1)

y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C
y_true = y_true.permute(0, 2, 1) # H, C, H*W
if self.ignore_index is not None:
mask = y_true != self.ignore_index
y_pred = y_pred * mask.unsqueeze(1)

y_true = F.one_hot(
(y_true * mask).to(torch.long), num_classes
) # N,H*W -> N,H*W, C
y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W
else:
y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C
y_true = y_true.permute(0, 2, 1) # N, C, H*W

if self.mode == MULTILABEL_MODE:
y_true = y_true.view(bs, num_classes, -1)
y_pred = y_pred.view(bs, num_classes, -1)

if self.ignore_index is not None:
mask = y_true != self.ignore_index
y_pred = y_pred * mask
y_true = y_true * mask

scores = soft_jaccard_score(
y_pred,
y_true.type(y_pred.dtype),
Expand Down
Loading