Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 18 additions & 24 deletions segmentation_models_pytorch/losses/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional

import torch
import torch.linalg as LA
import torch.nn.functional as F

__all__ = [
Expand Down Expand Up @@ -157,15 +158,7 @@ def soft_jaccard_score(
dims=None,
) -> torch.Tensor:
assert output.size() == target.size()
if dims is not None:
intersection = torch.sum(output * target, dim=dims)
cardinality = torch.sum(output + target, dim=dims)
else:
intersection = torch.sum(output * target)
cardinality = torch.sum(output + target)

union = cardinality - intersection
jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps)
jaccard_score = soft_tversky_score(output, target, 1.0, 1.0, smooth, eps, dims)
return jaccard_score


Expand All @@ -177,13 +170,7 @@ def soft_dice_score(
dims=None,
) -> torch.Tensor:
assert output.size() == target.size()
if dims is not None:
intersection = torch.sum(output * target, dim=dims)
cardinality = torch.sum(output + target, dim=dims)
else:
intersection = torch.sum(output * target)
cardinality = torch.sum(output + target)
dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
dice_score = soft_tversky_score(output, target, 0.5, 0.5, smooth, eps, dims)
return dice_score


Expand All @@ -196,15 +183,22 @@ def soft_tversky_score(
eps: float = 1e-7,
dims=None,
) -> torch.Tensor:
"""Tversky loss

References:
https://arxiv.org/pdf/2302.05666
https://arxiv.org/pdf/2303.16296

"""
assert output.size() == target.size()
if dims is not None:
intersection = torch.sum(output * target, dim=dims) # TP
fp = torch.sum(output * (1.0 - target), dim=dims)
fn = torch.sum((1 - output) * target, dim=dims)
else:
intersection = torch.sum(output * target) # TP
fp = torch.sum(output * (1.0 - target))
fn = torch.sum((1 - output) * target)

output_sum = torch.sum(output, dim=dims)
target_sum = torch.sum(target, dim=dims)
difference = LA.vector_norm(output - target, ord=1, dim=dims)

intersection = (output_sum + target_sum - difference) / 2 # TP
fp = output_sum - intersection
fn = target_sum - intersection

tversky_score = (intersection + smooth) / (
intersection + alpha * fp + beta * fn + smooth
Expand Down