@@ -17,6 +17,7 @@ def __init__(
1717 log_loss : bool = False ,
1818 from_logits : bool = True ,
1919 smooth : float = 0.0 ,
20+ ignore_index : Optional [int ] = None ,
2021 eps : float = 1e-7 ,
2122 ):
2223 """Jaccard loss for image segmentation task.
@@ -51,6 +52,7 @@ def __init__(
5152 self .classes = classes
5253 self .from_logits = from_logits
5354 self .smooth = smooth
55+ self .ignore_index = ignore_index
5456 self .eps = eps
5557 self .log_loss = log_loss
5658
@@ -74,17 +76,36 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7476 y_true = y_true .view (bs , 1 , - 1 )
7577 y_pred = y_pred .view (bs , 1 , - 1 )
7678
79+ if self .ignore_index is not None :
80+ mask = y_true != self .ignore_index
81+ y_pred = y_pred * mask
82+ y_true = y_true * mask
83+
7784 if self .mode == MULTICLASS_MODE :
7885 y_true = y_true .view (bs , - 1 )
7986 y_pred = y_pred .view (bs , num_classes , - 1 )
8087
81- y_true = F .one_hot (y_true , num_classes ) # N,H*W -> N,H*W, C
82- y_true = y_true .permute (0 , 2 , 1 ) # H, C, H*W
88+ if self .ignore_index is not None :
89+ mask = y_true != self .ignore_index
90+ y_pred = y_pred * mask .unsqueeze (1 )
91+
92+ y_true = F .one_hot (
93+ (y_true * mask ).to (torch .long ), num_classes
94+ ) # N,H*W -> N,H*W, C
95+ y_true = y_true .permute (0 , 2 , 1 ) * mask .unsqueeze (1 ) # N, C, H*W
96+ else :
97+ y_true = F .one_hot (y_true , num_classes ) # N,H*W -> N,H*W, C
98+ y_true = y_true .permute (0 , 2 , 1 ) # N, C, H*W
8399
84100 if self .mode == MULTILABEL_MODE :
85101 y_true = y_true .view (bs , num_classes , - 1 )
86102 y_pred = y_pred .view (bs , num_classes , - 1 )
87103
104+ if self .ignore_index is not None :
105+ mask = y_true != self .ignore_index
106+ y_pred = y_pred * mask
107+ y_true = y_true * mask
108+
88109 scores = soft_jaccard_score (
89110 y_pred ,
90111 y_true .type (y_pred .dtype ),
0 commit comments