Skip to content

Commit 58a0a8f

Browse files
committed
Modify Jaccard, Dice and Tversky losses
1 parent 41a6fe5 commit 58a0a8f

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

segmentation_models_pytorch/losses/_functional.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,7 @@ def soft_jaccard_score(
157157
dims=None,
158158
) -> torch.Tensor:
159159
assert output.size() == target.size()
160-
if dims is not None:
161-
intersection = torch.sum(output * target, dim=dims)
162-
cardinality = torch.sum(output + target, dim=dims)
163-
else:
164-
intersection = torch.sum(output * target)
165-
cardinality = torch.sum(output + target)
166-
167-
union = cardinality - intersection
168-
jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps)
160+
jaccard_score = soft_tversky_score(output, target, 1.0, 1.0, smooth, eps, dims)
169161
return jaccard_score
170162

171163

@@ -177,13 +169,7 @@ def soft_dice_score(
177169
dims=None,
178170
) -> torch.Tensor:
179171
assert output.size() == target.size()
180-
if dims is not None:
181-
intersection = torch.sum(output * target, dim=dims)
182-
cardinality = torch.sum(output + target, dim=dims)
183-
else:
184-
intersection = torch.sum(output * target)
185-
cardinality = torch.sum(output + target)
186-
dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
172+
dice_score = soft_tversky_score(output, target, 0.5, 0.5, smooth, eps, dims)
187173
return dice_score
188174

189175

@@ -196,15 +182,28 @@ def soft_tversky_score(
196182
eps: float = 1e-7,
197183
dims=None,
198184
) -> torch.Tensor:
185+
"""Tversky loss
186+
187+
References:
188+
https://arxiv.org/pdf/2302.05666
189+
https://arxiv.org/pdf/2303.16296
190+
191+
"""
199192
assert output.size() == target.size()
200193
if dims is not None:
201-
intersection = torch.sum(output * target, dim=dims) # TP
202-
fp = torch.sum(output * (1.0 - target), dim=dims)
203-
fn = torch.sum((1 - output) * target, dim=dims)
194+
difference = torch.norm(output - target, p=1, dim=dims)
195+
output_sum = torch.sum(output, dim=dims)
196+
target_sum = torch.sum(target, dim=dims)
197+
intersection = (output_sum + target_sum - difference) / 2 # TP
198+
fp = output_sum - intersection
199+
fn = target_sum - intersection
204200
else:
205-
intersection = torch.sum(output * target) # TP
206-
fp = torch.sum(output * (1.0 - target))
207-
fn = torch.sum((1 - output) * target)
201+
difference = torch.norm(output - target, p=1)
202+
output_sum = torch.sum(output)
203+
target_sum = torch.sum(target)
204+
intersection = (output_sum + target_sum - difference) / 2 # TP
205+
fp = output_sum - intersection
206+
fn = target_sum - intersection
208207

209208
tversky_score = (intersection + smooth) / (
210209
intersection + alpha * fp + beta * fn + smooth

0 commit comments

Comments
 (0)