
Commit 66615c0

NicolasHug authored and facebook-github-bot committed
[fbsync] feat: Raise ValueError for alpha > 1 in sigmoid_focal_loss (#8882)
Reviewed By: scotts
Differential Revision: D77997070
fbshipit-source-id: be9289a4b3fab0447d17ae18ceeddd617cc9ab75
Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
1 parent d1fb549 commit 66615c0

1 file changed: +4 −1 lines changed


torchvision/ops/focal_loss.py

Lines changed: 4 additions & 1 deletion
@@ -20,7 +20,7 @@ def sigmoid_focal_loss(
         targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
             classification label for each element in inputs
             (0 for the negative class and 1 for the positive class).
-        alpha (float): Weighting factor in range (0,1) to balance
+        alpha (float): Weighting factor in range [0, 1] to balance
             positive vs negative examples or -1 for ignore. Default: ``0.25``.
         gamma (float): Exponent of the modulating factor (1 - p_t) to
             balance easy vs hard examples. Default: ``2``.
@@ -33,6 +33,9 @@ def sigmoid_focal_loss(
     """
     # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py

+    if not (0 <= alpha <= 1) and alpha != -1:
+        raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")
+
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(sigmoid_focal_loss)
     p = torch.sigmoid(inputs)
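
For context, a minimal sketch of what the new check means for callers. The tensor shapes and values below are illustrative only, not taken from the commit:

import torch
from torchvision.ops import sigmoid_focal_loss

# Illustrative logits and binary {0, 1} targets of the same shape.
inputs = torch.randn(8)
targets = torch.randint(0, 2, (8,)).float()

# Valid: alpha in [0, 1], or -1 to disable the positive/negative weighting.
loss = sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2, reduction="mean")

# After this change, an out-of-range alpha raises a ValueError up front
# instead of silently producing a mis-weighted loss.
try:
    sigmoid_focal_loss(inputs, targets, alpha=1.5)
except ValueError as err:
    print(err)  # Invalid alpha value: 1.5. alpha must be in the range [0,1] or -1 for ignore.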
