@@ -53,6 +53,7 @@ def __init__(
53
53
self .smooth = smooth
54
54
self .eps = eps
55
55
self .log_loss = log_loss
56
+ self .ignore_index = ignore_index
56
57
57
58
def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
58
59
@@ -75,17 +76,34 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
75
76
y_true = y_true .view (bs , 1 , - 1 )
76
77
y_pred = y_pred .view (bs , 1 , - 1 )
77
78
79
+ if self .ignore_index is not None :
80
+ mask = y_true != self .ignore_index
81
+ y_pred = y_pred * mask
82
+ y_true = y_true * mask
83
+
78
84
if self .mode == MULTICLASS_MODE :
79
85
y_true = y_true .view (bs , - 1 )
80
86
y_pred = y_pred .view (bs , num_classes , - 1 )
81
87
82
- y_true = F .one_hot (y_true , num_classes ) # N,H*W -> N,H*W, C
83
- y_true = y_true .permute (0 , 2 , 1 ) # H, C, H*W
88
+ if self .ignore_index is not None :
89
+ mask = y_true != self .ignore_index
90
+ y_pred = y_pred * mask .unsqueeze (1 )
91
+
92
+ y_true = F .one_hot ((y_true * mask ).to (torch .long ), num_classes ) # N,H*W -> N,H*W, C
93
+ y_true = y_true .permute (0 , 2 , 1 ) * mask .unsqueeze (1 ) # H, C, H*W
94
+ else :
95
+ y_true = F .one_hot (y_true , num_classes ) # N,H*W -> N,H*W, C
96
+ y_true = y_true .permute (0 , 2 , 1 ) # H, C, H*W
84
97
85
98
if self .mode == MULTILABEL_MODE :
86
99
y_true = y_true .view (bs , num_classes , - 1 )
87
100
y_pred = y_pred .view (bs , num_classes , - 1 )
88
101
102
+ if self .ignore_index is not None :
103
+ mask = y_true != self .ignore_index
104
+ y_pred = y_pred * mask
105
+ y_true = y_true * mask
106
+
89
107
scores = soft_dice_score (y_pred , y_true .type_as (y_pred ), smooth = self .smooth , eps = self .eps , dims = dims )
90
108
91
109
if self .log_loss :
@@ -104,4 +122,4 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
104
122
if self .classes is not None :
105
123
loss = loss [self .classes ]
106
124
107
- return loss .mean ()
125
+ return loss .mean ()
0 commit comments