@@ -30,8 +30,8 @@ class LabelSmoothedCrossEntropyLoss(nn.Module):
3030 reduction (str): reduction method [sum, mean] (default: sum)
3131 architecture (str): speech model`s model [las, transformer] (default: las)
3232
33- Inputs: logit , target
34- logit (torch.Tensor): probability distribution value from model and it has a logarithm shape
33+ Inputs: logits , target
34+ logits (torch.Tensor): probability distribution value from model and it has a logarithm shape
3535 target (torch.Tensor): ground-thruth encoded to integers which directly point a word in label
3636
3737 Returns: label_smoothed
@@ -44,7 +44,7 @@ def __init__(
4444 smoothing : float = 0.1 , # ratio of smoothing (confidence = 1.0 - smoothing)
4545 dim : int = - 1 , # dimension of caculation loss
4646 reduction = 'sum' , # reduction method [sum, mean]
47- architecture = 'las' # speech model`s model [las, transformer]
47+ architecture = 'las' , # speech model`s model [las, transformer]
4848 ) -> None :
4949 super (LabelSmoothedCrossEntropyLoss , self ).__init__ ()
5050 self .confidence = 1.0 - smoothing
@@ -62,16 +62,16 @@ def __init__(
6262 else :
6363 raise ValueError ("Unsupported reduction method {0}" .format (reduction ))
6464
65- def forward (self , logit : Tensor , target : Tensor ):
65+ def forward (self , logits : Tensor , targets : Tensor ):
6666 if self .architecture == 'transformer' :
67- logit = F .log_softmax (logit , dim = - 1 )
67+ logits = F .log_softmax (logits , dim = - 1 )
6868
6969 if self .smoothing > 0.0 :
7070 with torch .no_grad ():
71- label_smoothed = torch .zeros_like (logit )
71+ label_smoothed = torch .zeros_like (logits )
7272 label_smoothed .fill_ (self .smoothing / (self .num_classes - 1 ))
73- label_smoothed .scatter_ (1 , target .data .unsqueeze (1 ), self .confidence )
74- label_smoothed [target == self .ignore_index , :] = 0
75- return self .reduction_method (- label_smoothed * logit )
73+ label_smoothed .scatter_ (1 , targets .data .unsqueeze (1 ), self .confidence )
74+ label_smoothed [targets == self .ignore_index , :] = 0
75+ return self .reduction_method (- label_smoothed * logits )
7676
77- return F .cross_entropy (logit , target , ignore_index = self .ignore_index , reduction = self .reduction )
77+ return F .cross_entropy (logits , targets , ignore_index = self .ignore_index , reduction = self .reduction )
0 commit comments