Skip to content

Commit c0170ac

Browse files
authored
ignore_index now works in DiceLoss (#366)
1 parent c606950 commit c0170ac

File tree

1 file changed

+21
-3
lines changed
  • segmentation_models_pytorch/losses

1 file changed

+21
-3
lines changed

segmentation_models_pytorch/losses/dice.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
self.smooth = smooth
5454
self.eps = eps
5555
self.log_loss = log_loss
56+
self.ignore_index = ignore_index
5657

5758
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
5859

@@ -75,17 +76,34 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7576
y_true = y_true.view(bs, 1, -1)
7677
y_pred = y_pred.view(bs, 1, -1)
7778

79+
if self.ignore_index is not None:
80+
mask = y_true != self.ignore_index
81+
y_pred = y_pred * mask
82+
y_true = y_true * mask
83+
7884
if self.mode == MULTICLASS_MODE:
7985
y_true = y_true.view(bs, -1)
8086
y_pred = y_pred.view(bs, num_classes, -1)
8187

82-
y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C
83-
y_true = y_true.permute(0, 2, 1) # H, C, H*W
88+
if self.ignore_index is not None:
89+
mask = y_true != self.ignore_index
90+
y_pred = y_pred * mask.unsqueeze(1)
91+
92+
y_true = F.one_hot((y_true * mask).to(torch.long), num_classes) # N,H*W -> N,H*W, C
93+
y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # H, C, H*W
94+
else:
95+
y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C
96+
y_true = y_true.permute(0, 2, 1) # H, C, H*W
8497

8598
if self.mode == MULTILABEL_MODE:
8699
y_true = y_true.view(bs, num_classes, -1)
87100
y_pred = y_pred.view(bs, num_classes, -1)
88101

102+
if self.ignore_index is not None:
103+
mask = y_true != self.ignore_index
104+
y_pred = y_pred * mask
105+
y_true = y_true * mask
106+
89107
scores = soft_dice_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)
90108

91109
if self.log_loss:
@@ -104,4 +122,4 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
104122
if self.classes is not None:
105123
loss = loss[self.classes]
106124

107-
return loss.mean()
125+
return loss.mean()

0 commit comments

Comments
 (0)