From 4c221b0fa568ce8a8ff862f83a943a4f8e8a2fc6 Mon Sep 17 00:00:00 2001 From: Aditi Mhatre Date: Sat, 5 Oct 2024 16:58:59 +0200 Subject: [PATCH 1/7] Created Focal Tversky file --- .../losses/focal_tversky.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 segmentation_models_pytorch/losses/focal_tversky.py diff --git a/segmentation_models_pytorch/losses/focal_tversky.py b/segmentation_models_pytorch/losses/focal_tversky.py new file mode 100644 index 00000000..9d1fb870 --- /dev/null +++ b/segmentation_models_pytorch/losses/focal_tversky.py @@ -0,0 +1,68 @@ +from typing import List, Optional + +import torch +from ._functional import soft_tversky_score +from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE +from .dice import DiceLoss + +__all__ = ["TverskyLoss"] + + +class FocalTverskyLoss(DiceLoss): + """Focal Tversky loss for image segmentation tasks. + FP and FN are weighted by alpha and beta parameters, respectively. + With alpha == beta == 0.5, this loss becomes equivalent to DiceLoss. + The gamma parameter focuses the loss on hard-to-classify examples. + If gamma > 1, the function focuses more on misclassified examples. + If gamma = 1, it is equivalent to Tversky Loss. + This loss supports binary, multiclass, and multilabel cases. + + Args: + mode: Metric mode {'binary', 'multiclass', 'multilabel'} + classes: Optional list of classes that contribute to loss computation; + By default, all channels are included. + log_loss: If True, computes loss as ``-log(tversky)``; otherwise as ``1 - tversky`` + from_logits: If True, assumes input is raw logits + smooth: Smoothing factor to avoid division by zero + ignore_index: Label indicating ignored pixels (not contributing to loss) + eps: Small epsilon for numerical stability + alpha: Weight constant that penalizes False Positives (FPs) + beta: Weight constant that penalizes False Negatives (FNs) + gamma: Focal factor to adjust the focus on harder examples (defaults to 1.0) + + + Return: + loss: torch.Tensor + + """ + + def __init__( + self, + mode: str, + classes: List[int] = None, + log_loss: bool = False, + from_logits: bool = True, + smooth: float = 0.0, + ignore_index: Optional[int] = None, + eps: float = 1e-7, + alpha: float = 0.5, + beta: float = 0.5, + gamma: float = 1.0, + ): + assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} + super().__init__( + mode, classes, log_loss, from_logits, smooth, ignore_index, eps + ) + self.alpha = alpha + self.beta = beta + self.gamma = gamma + + def aggregate_loss(self, loss): + return loss.mean() ** self.gamma + + def compute_score( + self, output, target, smooth=0.0, eps=1e-7, dims=None + ) -> torch.Tensor: + return soft_tversky_score( + output, target, self.alpha, self.beta, smooth, eps, dims + ) From 0fab35ab6ebcb677d35d0f603537596237b70e81 Mon Sep 17 00:00:00 2001 From: Aditi Mhatre Date: Sat, 5 Oct 2024 17:08:06 +0200 Subject: [PATCH 2/7] Added Focal Tversky Function --- segmentation_models_pytorch/losses/focal_tversky.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/losses/focal_tversky.py b/segmentation_models_pytorch/losses/focal_tversky.py index 9d1fb870..3a2f9a9a 100644 --- a/segmentation_models_pytorch/losses/focal_tversky.py +++ b/segmentation_models_pytorch/losses/focal_tversky.py @@ -63,6 +63,7 @@ def aggregate_loss(self, loss): def compute_score( self, output, target, smooth=0.0, eps=1e-7, dims=None ) -> torch.Tensor: - return soft_tversky_score( + tversky_score = soft_tversky_score( output, target, self.alpha, self.beta, smooth, eps, dims ) + return (1 - tversky_score) ** self.gamma \ No newline at end of file From 2c38b8019c41c33381f065a0053972efe7bff605 Mon Sep 17 00:00:00 2001 From: Aditi Mhatre Date: Sat, 5 Oct 2024 17:15:34 +0200 Subject: [PATCH 3/7] Initialized Test for Focal Tversky Loss --- tests/test_losses.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_losses.py b/tests/test_losses.py index 5c3ad75a..e8383506 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -92,6 +92,21 @@ def test_soft_tversky_score(y_true, y_pred, expected, eps, alpha, beta): actual = F.soft_tversky_score(y_pred, y_true, eps=eps, alpha=alpha, beta=beta) assert float(actual) == pytest.approx(expected, eps) +@pytest.mark.parametrize( + ["y_true", "y_pred", "expected", "eps", "alpha", "beta", "gamma"], + [ + [[1, 1, 1, 1], [1, 1, 1, 1], 1.0, 1e-5, 0.5, 0.5], + [[0, 1, 1, 0], [0, 1, 1, 0], 1.0, 1e-5, 0.5, 0.5], + [[1, 1, 1, 1], [1, 1, 0, 0], 2.0 / 3.0, 1e-5, 0.5, 0.5, 2.0], + ], +) + +def test_focal_tversky_score(y_true, y_pred, expected, eps, alpha, beta, gamma): + y_true = torch.tensor(y_true, dtype=torch.float32) + y_pred = torch.tensor(y_pred, dtype=torch.float32) + actual = F.soft_tversky_score(y_pred, y_true, eps=eps, alpha=alpha, beta=beta, gamma=gamma) + assert float(actual) == pytest.approx(expected, eps) + @torch.no_grad() def test_dice_loss_binary(): From 550260fa15e1eef3e8ef8e62b38b9f32804adc73 Mon Sep 17 00:00:00 2001 From: Aditi Mhatre Date: Sat, 5 Oct 2024 21:59:53 +0200 Subject: [PATCH 4/7] Added test for focal tversky loss --- tests/test_losses.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_losses.py b/tests/test_losses.py index e8383506..51a72226 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -9,6 +9,7 @@ SoftCrossEntropyLoss, TverskyLoss, MCCLoss, + Focal_Tversky, ) @@ -95,9 +96,14 @@ def test_soft_tversky_score(y_true, y_pred, expected, eps, alpha, beta): @pytest.mark.parametrize( ["y_true", "y_pred", "expected", "eps", "alpha", "beta", "gamma"], [ - [[1, 1, 1, 1], [1, 1, 1, 1], 1.0, 1e-5, 0.5, 0.5], - [[0, 1, 1, 0], [0, 1, 1, 0], 1.0, 1e-5, 0.5, 0.5], - [[1, 1, 1, 1], [1, 1, 0, 0], 2.0 / 3.0, 1e-5, 0.5, 0.5, 2.0], + # [[1, 1, 1, 1], [1, 1, 1, 1], 1.0, 1e-5, 0.5, 0.5, 1.0], + # [[0, 1, 1, 0], [0, 1, 1, 0], 1.0, 1e-5, 0.5, 0.5, 2.0], + # [[1, 1, 1, 1], [1, 1, 0, 0], 2.0 / 3.0, 1e-5, 0.5, 0.5, 3.0], + [[1, 1, 1, 1], [1, 1, 1, 1], 0.0, 1e-5, 0.5, 0.5, 1.0], + [[0, 1, 1, 0], [0, 1, 1, 0], 0.0, 1e-5, 0.5, 0.5, 1.0], + [[1, 1, 1, 1], [0, 0, 0, 0], 1.0, 1e-5, 0.5, 0.5, 1.0], + [[1, 0, 1, 0], [0, 1, 0, 0], 1.0, 1e-5, 0.5, 0.5, 1.0], + [[1, 0, 1, 0], [1, 1, 0, 0], 0.5, 1e-5, 0.5, 0.5, 1.0], ], ) From 6e77fb185521d8e14dcb8f47825ef301593bac89 Mon Sep 17 00:00:00 2001 From: Aditi Mhatre Date: Sat, 5 Oct 2024 22:09:01 +0200 Subject: [PATCH 5/7] Added test for focal tversky loss --- tests/test_losses.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/test_losses.py b/tests/test_losses.py index 51a72226..fa37181d 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -9,7 +9,7 @@ SoftCrossEntropyLoss, TverskyLoss, MCCLoss, - Focal_Tversky, + focal_tversky, ) @@ -96,9 +96,6 @@ def test_soft_tversky_score(y_true, y_pred, expected, eps, alpha, beta): @pytest.mark.parametrize( ["y_true", "y_pred", "expected", "eps", "alpha", "beta", "gamma"], [ - # [[1, 1, 1, 1], [1, 1, 1, 1], 1.0, 1e-5, 0.5, 0.5, 1.0], - # [[0, 1, 1, 0], [0, 1, 1, 0], 1.0, 1e-5, 0.5, 0.5, 2.0], - # [[1, 1, 1, 1], [1, 1, 0, 0], 2.0 / 3.0, 1e-5, 0.5, 0.5, 3.0], [[1, 1, 1, 1], [1, 1, 1, 1], 0.0, 1e-5, 0.5, 0.5, 1.0], [[0, 1, 1, 0], [0, 1, 1, 0], 0.0, 1e-5, 0.5, 0.5, 1.0], [[1, 1, 1, 1], [0, 0, 0, 0], 1.0, 1e-5, 0.5, 0.5, 1.0], @@ -110,7 +107,7 @@ def test_soft_tversky_score(y_true, y_pred, expected, eps, alpha, beta): def test_focal_tversky_score(y_true, y_pred, expected, eps, alpha, beta, gamma): y_true = torch.tensor(y_true, dtype=torch.float32) y_pred = torch.tensor(y_pred, dtype=torch.float32) - actual = F.soft_tversky_score(y_pred, y_true, eps=eps, alpha=alpha, beta=beta, gamma=gamma) + actual = F.focal_tversky_loss(y_pred, y_true, eps=eps, alpha=alpha, beta=beta, gamma=gamma) assert float(actual) == pytest.approx(expected, eps) From 81478535b60384ff91e3682748158ccbcf7509e6 Mon Sep 17 00:00:00 2001 From: Aditi Mhatre Date: Sat, 5 Oct 2024 22:10:20 +0200 Subject: [PATCH 6/7] Added Focal Tversky loss --- .../losses/_functional.py | 31 +++++++++++++++++++ .../losses/focal_tversky.py | 8 ++--- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/segmentation_models_pytorch/losses/_functional.py b/segmentation_models_pytorch/losses/_functional.py index a26f3f48..62c4faba 100644 --- a/segmentation_models_pytorch/losses/_functional.py +++ b/segmentation_models_pytorch/losses/_functional.py @@ -205,6 +205,37 @@ def soft_tversky_score( ).clamp_min(eps) return tversky_score +def focal_tversky_loss( + output: torch.Tensor, + target: torch.Tensor, + alpha: float, + beta: float, + gamma: float, + smooth: float = 0.0, + eps: float = 1e-7, + dims=None, +) -> torch.Tensor: + """Focal Tversky loss + + References: + https://arxiv.org/pdf/1810.07842 + + """ + assert output.size() == target.size() + + output_sum = torch.sum(output, dim=dims) + target_sum = torch.sum(target, dim=dims) + difference = LA.vector_norm(output - target, ord=1, dim=dims) + + intersection = (output_sum + target_sum - difference) / 2 # TP + fp = output_sum - intersection + fn = target_sum - intersection + + tversky_score = (intersection + smooth) / ( + intersection + alpha * fp + beta * fn + smooth + ).clamp_min(eps) + tversky_loss = (1 - tversky_score) ** gamma + return tversky_loss def wing_loss( output: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean" diff --git a/segmentation_models_pytorch/losses/focal_tversky.py b/segmentation_models_pytorch/losses/focal_tversky.py index 3a2f9a9a..dd0659f4 100644 --- a/segmentation_models_pytorch/losses/focal_tversky.py +++ b/segmentation_models_pytorch/losses/focal_tversky.py @@ -1,7 +1,7 @@ from typing import List, Optional import torch -from ._functional import soft_tversky_score +from ._functional import focal_tversky_loss from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE from .dice import DiceLoss @@ -63,7 +63,7 @@ def aggregate_loss(self, loss): def compute_score( self, output, target, smooth=0.0, eps=1e-7, dims=None ) -> torch.Tensor: - tversky_score = soft_tversky_score( - output, target, self.alpha, self.beta, smooth, eps, dims + tversky_score = focal_tversky_loss( + output, target, self.alpha, self.beta, self.gamma, smooth, eps, dims ) - return (1 - tversky_score) ** self.gamma \ No newline at end of file + return tversky_score \ No newline at end of file From cf7164621fe783beae9edca0ae0fbc311af89877 Mon Sep 17 00:00:00 2001 From: Aditi Mhatre Date: Sat, 5 Oct 2024 22:14:27 +0200 Subject: [PATCH 7/7] Initialized Focal Tversky loss --- segmentation_models_pytorch/losses/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/segmentation_models_pytorch/losses/__init__.py b/segmentation_models_pytorch/losses/__init__.py index 10b69c83..a9c57fa2 100644 --- a/segmentation_models_pytorch/losses/__init__.py +++ b/segmentation_models_pytorch/losses/__init__.py @@ -8,6 +8,7 @@ from .soft_ce import SoftCrossEntropyLoss from .tversky import TverskyLoss from .mcc import MCCLoss +from .focal_tversky import FocalTverskyLoss __all__ = [ "BINARY_MODE", @@ -21,4 +22,5 @@ "SoftCrossEntropyLoss", "TverskyLoss", "MCCLoss", + "FocalTverskyLoss", ]