Added Focal Tversky loss #932
Changes from all commits: 4c221b0, 0fab35a, 2c38b80, 550260f, 6e77fb1, 8147853, cf71646
@@ -0,0 +1,69 @@ (new file)

```python
from typing import List, Optional

import torch

from ._functional import focal_tversky_loss
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
from .dice import DiceLoss

__all__ = ["FocalTverskyLoss"]


class FocalTverskyLoss(DiceLoss):
    """Focal Tversky loss for image segmentation tasks.

    FP and FN are weighted by the alpha and beta parameters, respectively.
    With alpha == beta == 0.5, this loss is equivalent to DiceLoss.
    The gamma parameter focuses the loss on hard-to-classify examples:
    if gamma > 1, the function focuses more on misclassified examples;
    if gamma == 1, it is equivalent to the Tversky loss.
    This loss supports binary, multiclass, and multilabel cases.

    Args:
        mode: Metric mode {'binary', 'multiclass', 'multilabel'}
        classes: Optional list of classes that contribute to the loss computation;
            by default, all channels are included.
        log_loss: If True, computes the loss as ``-log(tversky)``; otherwise as ``1 - tversky``
        from_logits: If True, assumes the input is raw logits
        smooth: Smoothing factor to avoid division by zero
        ignore_index: Label indicating ignored pixels (not contributing to the loss)
        eps: Small epsilon for numerical stability
        alpha: Weight constant that penalizes False Positives (FPs)
        beta: Weight constant that penalizes False Negatives (FNs)
        gamma: Focal factor that adjusts the focus on harder examples (defaults to 1.0)

    Returns:
        loss: torch.Tensor
    """

    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
        alpha: float = 0.5,
        beta: float = 0.5,
        gamma: float = 1.0,
    ):
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__(
            mode, classes, log_loss, from_logits, smooth, ignore_index, eps
        )
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def aggregate_loss(self, loss):
        return loss.mean() ** self.gamma

    def compute_score(
        self, output, target, smooth=0.0, eps=1e-7, dims=None
    ) -> torch.Tensor:
        tversky_score = focal_tversky_loss(
            output, target, self.alpha, self.beta, self.gamma, smooth, eps, dims
        )
        return tversky_score
```

> Review comment (on the `assert mode in ...` line): Please move the class docstring under …
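For context, a minimal usage sketch of the proposed class. This is hedged: it assumes `FocalTverskyLoss` is re-exported from `segmentation_models_pytorch.losses` (which this diff does not show) and that the inherited `DiceLoss.forward` accepts raw logits and an integer mask as the other losses in the package do:

```python
import torch

from segmentation_models_pytorch.losses import FocalTverskyLoss  # assumed re-export

# beta > alpha penalizes false negatives more than false positives;
# gamma > 1 shifts the focus toward hard-to-classify examples.
criterion = FocalTverskyLoss(mode="binary", alpha=0.3, beta=0.7, gamma=2.0)

logits = torch.randn(4, 1, 64, 64, requires_grad=True)  # raw model outputs
target = (torch.rand(4, 1, 64, 64) > 0.5).long()        # binary ground-truth mask

loss = criterion(logits, target)
loss.backward()
```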
@@ -9,6 +9,7 @@

```python
    SoftCrossEntropyLoss,
    TverskyLoss,
    MCCLoss,
    focal_tversky,
)
```
@@ -92,6 +93,23 @@ def test_soft_tversky_score(y_true, y_pred, expected, eps, alpha, beta):

```python
    actual = F.soft_tversky_score(y_pred, y_true, eps=eps, alpha=alpha, beta=beta)
    assert float(actual) == pytest.approx(expected, eps)


@pytest.mark.parametrize(
    ["y_true", "y_pred", "expected", "eps", "alpha", "beta", "gamma"],
    [
        [[1, 1, 1, 1], [1, 1, 1, 1], 0.0, 1e-5, 0.5, 0.5, 1.0],
        [[0, 1, 1, 0], [0, 1, 1, 0], 0.0, 1e-5, 0.5, 0.5, 1.0],
        [[1, 1, 1, 1], [0, 0, 0, 0], 1.0, 1e-5, 0.5, 0.5, 1.0],
        [[1, 0, 1, 0], [0, 1, 0, 0], 1.0, 1e-5, 0.5, 0.5, 1.0],
        [[1, 0, 1, 0], [1, 1, 0, 0], 0.5, 1e-5, 0.5, 0.5, 1.0],
    ],
)
def test_focal_tversky_score(y_true, y_pred, expected, eps, alpha, beta, gamma):
    y_true = torch.tensor(y_true, dtype=torch.float32)
    y_pred = torch.tensor(y_pred, dtype=torch.float32)
    actual = F.focal_tversky_loss(y_pred, y_true, eps=eps, alpha=alpha, beta=beta, gamma=gamma)
    assert float(actual) == pytest.approx(expected, eps)
```

> Review comment (on lines +98 to +102): Can you please add a test case with gamma != 1?
>
> Author reply: Hi @qubvel, thanks a lot for your comments! Since @zifuwanggg mentioned that the loss function is already implemented, I'll close this pull request.
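In response to the review request above, here is a hedged sketch of what a gamma != 1 case could look like, reusing the test module's existing imports (`pytest`, `torch`, `F`). The expected value assumes `focal_tversky_loss` computes `(1 - tversky_index) ** gamma`, which is consistent with the gamma == 1 expectations above: for `y_true = [1, 0, 1, 0]` and `y_pred = [1, 1, 0, 0]`, the Tversky index with `alpha = beta = 0.5` is 0.5, so gamma = 2.0 would give `(1 - 0.5) ** 2 == 0.25`.

```python
# Sketch only (not part of the PR): one case with gamma != 1, assuming the
# loss is (1 - tversky_index) ** gamma.
@pytest.mark.parametrize(
    ["y_true", "y_pred", "expected", "eps", "alpha", "beta", "gamma"],
    [
        # TP=1, FP=1, FN=1 -> tversky = 1 / (1 + 0.5 + 0.5) = 0.5 -> loss = 0.5 ** 2
        [[1, 0, 1, 0], [1, 1, 0, 0], 0.25, 1e-5, 0.5, 0.5, 2.0],
    ],
)
def test_focal_tversky_score_gamma(y_true, y_pred, expected, eps, alpha, beta, gamma):
    y_true = torch.tensor(y_true, dtype=torch.float32)
    y_pred = torch.tensor(y_pred, dtype=torch.float32)
    actual = F.focal_tversky_loss(y_pred, y_true, eps=eps, alpha=alpha, beta=beta, gamma=gamma)
    assert float(actual) == pytest.approx(expected, eps)
```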
```python
@torch.no_grad()
def test_dice_loss_binary():
```
> Review comment: As far as I understand, we can just call `soft_tversky_score` here?
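The `_functional.focal_tversky_loss` helper imported by the new module is not part of this diff. Following the reviewer's suggestion, one plausible implementation is a thin wrapper around the library's existing `soft_tversky_score`: with the Tversky index `TI = TP / (TP + alpha * FP + beta * FN)`, the focal variant is `(1 - TI) ** gamma`. The sketch below is an assumption, not the PR's actual helper:

```python
import torch

from segmentation_models_pytorch.losses._functional import soft_tversky_score


def focal_tversky_loss(
    output: torch.Tensor,
    target: torch.Tensor,
    alpha: float,
    beta: float,
    gamma: float,
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims=None,
) -> torch.Tensor:
    # Tversky index in [0, 1]; alpha weights false positives, beta false negatives.
    tversky = soft_tversky_score(
        output, target, alpha=alpha, beta=beta, smooth=smooth, eps=eps, dims=dims
    )
    # Focal term: raising (1 - tversky) to gamma > 1 shrinks the loss on easy
    # examples (tversky near 1) faster than on hard ones.
    return (1.0 - tversky) ** gamma
```

Note that `aggregate_loss` in the new class also raises the mean loss to `self.gamma`, so applying `gamma` both there and inside the functional would compound the focal exponent; the reviewer's question points toward returning the plain `soft_tversky_score` from `compute_score` and leaving the focal exponent to `aggregate_loss`.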