Skip to content

Commit e0e99c5

Browse files
fmassaoke-aditya
andauthored
adds docs for focal loss (#2979) (#3008)
Co-authored-by: Aditya Oke <[email protected]>
1 parent 8db65b1 commit e0e99c5

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

docs/source/ops.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ torchvision.ops
2222
.. autofunction:: roi_pool
2323
.. autofunction:: ps_roi_pool
2424
.. autofunction:: deform_conv2d
25+
.. autofunction:: sigmoid_focal_loss
2526

2627
.. autoclass:: RoIAlign
2728
.. autoclass:: PSRoIAlign

torchvision/ops/focal_loss.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,21 @@
33

44

55
def sigmoid_focal_loss(
6-
inputs,
7-
targets,
6+
inputs: torch.Tensor,
7+
targets: torch.Tensor,
88
alpha: float = 0.25,
99
gamma: float = 2,
1010
reduction: str = "none",
1111
):
1212
"""
1313
Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
1414
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
15-
Args:
15+
16+
Arguments:
1617
inputs: A float tensor of arbitrary shape.
1718
The predictions for each example.
1819
targets: A float tensor with the same shape as inputs. Stores the binary
19-
classification label for each element in inputs
20+
classification label for each element in inputs
2021
(0 for the negative class and 1 for the positive class).
2122
alpha: (optional) Weighting factor in range (0,1) to balance
2223
positive vs negative examples or -1 for ignore. Default = 0.25

0 commit comments

Comments
 (0)