Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions segmentation_models_pytorch/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .soft_ce import SoftCrossEntropyLoss
from .tversky import TverskyLoss
from .mcc import MCCLoss
from .focal_tversky import FocalTverskyLoss

__all__ = [
"BINARY_MODE",
Expand All @@ -21,4 +22,5 @@
"SoftCrossEntropyLoss",
"TverskyLoss",
"MCCLoss",
"FocalTverskyLoss",
]
31 changes: 31 additions & 0 deletions segmentation_models_pytorch/losses/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,37 @@ def soft_tversky_score(
).clamp_min(eps)
return tversky_score

def focal_tversky_loss(
output: torch.Tensor,
target: torch.Tensor,
alpha: float,
beta: float,
gamma: float,
smooth: float = 0.0,
eps: float = 1e-7,
dims=None,
) -> torch.Tensor:
"""Focal Tversky loss

References:
https://arxiv.org/pdf/1810.07842

"""
assert output.size() == target.size()

output_sum = torch.sum(output, dim=dims)
target_sum = torch.sum(target, dim=dims)
difference = LA.vector_norm(output - target, ord=1, dim=dims)

intersection = (output_sum + target_sum - difference) / 2 # TP
fp = output_sum - intersection
fn = target_sum - intersection

tversky_score = (intersection + smooth) / (
intersection + alpha * fp + beta * fn + smooth
).clamp_min(eps)
Comment on lines +226 to +236
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I understand we can just call soft_tversky_score here?

tversky_loss = (1 - tversky_score) ** gamma
return tversky_loss

def wing_loss(
output: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"
Expand Down
69 changes: 69 additions & 0 deletions segmentation_models_pytorch/losses/focal_tversky.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import List, Optional

import torch
from ._functional import focal_tversky_loss
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
from .dice import DiceLoss

__all__ = ["TverskyLoss"]

Check failure on line 8 in segmentation_models_pytorch/losses/focal_tversky.py

View workflow job for this annotation

GitHub Actions / style

Ruff (F822)

segmentation_models_pytorch/losses/focal_tversky.py:8:12: F822 Undefined name `TverskyLoss` in `__all__`


class FocalTverskyLoss(DiceLoss):
"""Focal Tversky loss for image segmentation tasks.
FP and FN are weighted by alpha and beta parameters, respectively.
With alpha == beta == 0.5, this loss becomes equivalent to DiceLoss.
The gamma parameter focuses the loss on hard-to-classify examples.
If gamma > 1, the function focuses more on misclassified examples.
If gamma = 1, it is equivalent to Tversky Loss.
This loss supports binary, multiclass, and multilabel cases.

Args:
mode: Metric mode {'binary', 'multiclass', 'multilabel'}
classes: Optional list of classes that contribute to loss computation;
By default, all channels are included.
log_loss: If True, computes loss as ``-log(tversky)``; otherwise as ``1 - tversky``
from_logits: If True, assumes input is raw logits
smooth: Smoothing factor to avoid division by zero
ignore_index: Label indicating ignored pixels (not contributing to loss)
eps: Small epsilon for numerical stability
alpha: Weight constant that penalizes False Positives (FPs)
beta: Weight constant that penalizes False Negatives (FNs)
gamma: Focal factor to adjust the focus on harder examples (defaults to 1.0)


Return:
loss: torch.Tensor

"""

def __init__(
self,
mode: str,
classes: List[int] = None,
log_loss: bool = False,
from_logits: bool = True,
smooth: float = 0.0,
ignore_index: Optional[int] = None,
eps: float = 1e-7,
alpha: float = 0.5,
beta: float = 0.5,
gamma: float = 1.0,
):
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
Copy link
Collaborator

@qubvel qubvel Oct 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move the class docstring under __init__ similar to other losses

super().__init__(
mode, classes, log_loss, from_logits, smooth, ignore_index, eps
)
self.alpha = alpha
self.beta = beta
self.gamma = gamma

def aggregate_loss(self, loss):
return loss.mean() ** self.gamma

def compute_score(
self, output, target, smooth=0.0, eps=1e-7, dims=None
) -> torch.Tensor:
tversky_score = focal_tversky_loss(
output, target, self.alpha, self.beta, self.gamma, smooth, eps, dims
)
return tversky_score
18 changes: 18 additions & 0 deletions tests/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SoftCrossEntropyLoss,
TverskyLoss,
MCCLoss,
focal_tversky,
)


Expand Down Expand Up @@ -92,6 +93,23 @@ def test_soft_tversky_score(y_true, y_pred, expected, eps, alpha, beta):
actual = F.soft_tversky_score(y_pred, y_true, eps=eps, alpha=alpha, beta=beta)
assert float(actual) == pytest.approx(expected, eps)

@pytest.mark.parametrize(
["y_true", "y_pred", "expected", "eps", "alpha", "beta", "gamma"],
[
[[1, 1, 1, 1], [1, 1, 1, 1], 0.0, 1e-5, 0.5, 0.5, 1.0],
[[0, 1, 1, 0], [0, 1, 1, 0], 0.0, 1e-5, 0.5, 0.5, 1.0],
[[1, 1, 1, 1], [0, 0, 0, 0], 1.0, 1e-5, 0.5, 0.5, 1.0],
[[1, 0, 1, 0], [0, 1, 0, 0], 1.0, 1e-5, 0.5, 0.5, 1.0],
Comment on lines +98 to +102
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add a test case with gamma != 1?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @qubvel, thanks a lot for your comments ! Since @zifuwanggg mentioned that the loss function is already implemented I'll close this pull request

[[1, 0, 1, 0], [1, 1, 0, 0], 0.5, 1e-5, 0.5, 0.5, 1.0],
],
)

def test_focal_tversky_score(y_true, y_pred, expected, eps, alpha, beta, gamma):
y_true = torch.tensor(y_true, dtype=torch.float32)
y_pred = torch.tensor(y_pred, dtype=torch.float32)
actual = F.focal_tversky_loss(y_pred, y_true, eps=eps, alpha=alpha, beta=beta, gamma=gamma)
assert float(actual) == pytest.approx(expected, eps)


@torch.no_grad()
def test_dice_loss_binary():
Expand Down
Loading