File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -20,7 +20,7 @@ def sigmoid_focal_loss(
2020 targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
2121 classification label for each element in inputs
2222 (0 for the negative class and 1 for the positive class).
23- alpha (float): Weighting factor in range (0,1) to balance
23+ alpha (float): Weighting factor in range [0, 1] to balance
2424 positive vs negative examples or -1 for ignore. Default: ``0.25``.
2525 gamma (float): Exponent of the modulating factor (1 - p_t) to
2626 balance easy vs hard examples. Default: ``2``.
@@ -33,6 +33,9 @@ def sigmoid_focal_loss(
3333 """
3434 # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
3535
36+ if not (0 <= alpha <= 1 ) or alpha != - 1 :
37+ raise ValueError (f"Invalid alpha value: { alpha } . alpha must be in the range [0,1] or -1 for ignore." )
38+
3639 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
3740 _log_api_usage_once (sigmoid_focal_loss )
3841 p = torch .sigmoid (inputs )
You can’t perform that action at this time.
0 commit comments