Skip to content

Commit c4cbd1e

Browse files
committed
Modify the Tversky loss
1 parent 58a0a8f commit c4cbd1e

File tree

1 file changed

+9
-14
lines changed

1 file changed

+9
-14
lines changed

segmentation_models_pytorch/losses/_functional.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Optional
55

66
import torch
7+
import torch.linalg as LA
78
import torch.nn.functional as F
89

910
__all__ = [
@@ -190,20 +191,14 @@ def soft_tversky_score(
190191
191192
"""
192193
assert output.size() == target.size()
193-
if dims is not None:
194-
difference = torch.norm(output - target, p=1, dim=dims)
195-
output_sum = torch.sum(output, dim=dims)
196-
target_sum = torch.sum(target, dim=dims)
197-
intersection = (output_sum + target_sum - difference) / 2 # TP
198-
fp = output_sum - intersection
199-
fn = target_sum - intersection
200-
else:
201-
difference = torch.norm(output - target, p=1)
202-
output_sum = torch.sum(output)
203-
target_sum = torch.sum(target)
204-
intersection = (output_sum + target_sum - difference) / 2 # TP
205-
fp = output_sum - intersection
206-
fn = target_sum - intersection
194+
195+
output_sum = torch.sum(output, dim=dims)
196+
target_sum = torch.sum(target, dim=dims)
197+
difference = LA.vector_norm(output - target, ord=1, dim=dims)
198+
199+
intersection = (output_sum + target_sum - difference) / 2 # TP
200+
fp = output_sum - intersection
201+
fn = target_sum - intersection
207202

208203
tversky_score = (intersection + smooth) / (
209204
intersection + alpha * fp + beta * fn + smooth

0 commit comments

Comments
 (0)