@@ -1101,6 +1101,105 @@ def softmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
   return loss
 
 
+def kl_divergence(y_true, y_pred, reduced_dim, weights=None, epsilon=1e-6):
+  """Kullback-Leibler divergence between `y_true` and `y_pred`.
+
+  Computes: `loss = y_true * log(y_true / y_pred)`
+  From: tf.keras.losses.KLDivergence (custom implementation with mtf)
+  See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
+
+  Args:
+    y_true: mtf.Tensor, the target distribution.
+    y_pred: mtf.Tensor, the predicted distribution.
+    reduced_dim: mtf.Dimension, reduction dimension for the sum.
+    weights: optional mtf.Tensor, indicator for padded regions.
+    epsilon: float, minimum value for numerical stability.
+  Returns:
+    scalar: the KL-divergence loss.
+  Raises:
+    ValueError: if the shapes do not match or `reduced_dim` is not valid.
+  """
+  if set(y_true.shape.dims) != set(y_pred.shape.dims):
+    raise ValueError(
+        "`y_true` and `y_pred` must be of the same shape. "
+        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
+  if reduced_dim not in y_true.shape.dims:
+    raise ValueError(
+        f"`reduced_dim` must be a valid dimension (from {y_true.shape.dims}).")
+  if weights is None:
+    weights = 1.
+
+  def _clip(x, min_value, max_value):
+    # Clip values for numerical stability.
+    x = mtf.maximum(x, min_value)
+    x = mtf.minimum(x, max_value)
+    return x
+
+  y_true = _clip(y_true, epsilon, 1.)
+  y_pred = _clip(y_pred, epsilon, 1.)
+  return mtf.reduce_sum(weights * y_true * mtf.log(y_true / y_pred))
+
+
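A minimal usage sketch for `kl_divergence` (illustrative only; the mesh setup, dimension names, and random inputs below are assumptions, not part of this change):

  import mesh_tensorflow as mtf

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 4)
  vocab = mtf.Dimension("vocab", 8)
  # Toy distributions over `vocab`, built by normalizing random logits.
  y_true = mtf.softmax(
      mtf.random_uniform(mesh, mtf.Shape([batch, vocab])), reduced_dim=vocab)
  y_pred = mtf.softmax(
      mtf.random_uniform(mesh, mtf.Shape([batch, vocab])), reduced_dim=vocab)
  loss = kl_divergence(y_true, y_pred, reduced_dim=vocab)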
+def mean_squared_error(y_true, y_pred, weights=None):
+  """L2 loss between `y_true` and `y_pred`.
+
+  Args:
+    y_true: mtf.Tensor, target logits.
+    y_pred: mtf.Tensor, predicted logits.
+    weights: optional mtf.Tensor, indicator for padded regions.
+  Returns:
+    scalar: the L2 loss.
+  Raises:
+    ValueError: if the shapes do not match.
+  """
+  if set(y_true.shape.dims) != set(y_pred.shape.dims):
+    raise ValueError(
+        "`y_true` and `y_pred` must be of the same shape. "
+        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
+  if weights is None:
+    weights = 1.
+  return mtf.reduce_sum(weights * mtf.square(y_true - y_pred))
+
+
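A hedged sketch of `mean_squared_error` with a padding mask; the 0/1 `pad_mask` below is a hypothetical stand-in for a real padding indicator, and it reuses `mesh`, `batch`, and `vocab` from the sketch above:

  y_true = mtf.random_uniform(mesh, mtf.Shape([batch, vocab]))
  y_pred = mtf.random_uniform(mesh, mtf.Shape([batch, vocab]))
  # Hypothetical mask: 1. for real positions, 0. for padded ones; mtf
  # broadcasts it over `vocab` when multiplying inside the loss.
  pad_mask = mtf.to_float(
      mtf.greater(mtf.random_uniform(mesh, mtf.Shape([batch])), 0.5))
  loss = mean_squared_error(y_true, y_pred, weights=pad_mask)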
+def cosine_embedding_distill(y_true, y_pred, reduced_dim, weights=None,
+                             epsilon=1e-6):
+  """Cosine embedding loss for distillation from teacher to student logits.
+
+  See: https://arxiv.org/abs/1910.01108 (DistilBERT) and
+  https://github.com/huggingface/transformers/tree/master/examples/
+  research_projects/distillation.
+
+  Args:
+    y_true: mtf.Tensor, teacher logits.
+    y_pred: mtf.Tensor, student logits.
+    reduced_dim: mtf.Dimension, reduction dimension for the sum.
+    weights: optional mtf.Tensor, indicator for padded regions.
+    epsilon: float, for numerical stability.
+  Returns:
+    scalar: sum of cosine embedding distances.
+  Raises:
+    ValueError: if the shapes do not match or `reduced_dim` is not valid.
+  """
+  if set(y_true.shape.dims) != set(y_pred.shape.dims):
+    raise ValueError(
+        "`y_true` and `y_pred` must be of the same shape. "
+        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
+  if reduced_dim not in y_true.shape.dims:
+    raise ValueError(
+        f"`reduced_dim` must be a valid dimension (from {y_true.shape.dims}).")
+  if weights is None:
+    weights = 1.
+
+  # Cosine similarity along `reduced_dim`:
+  # cos = <y_true, y_pred> / (||y_true|| * ||y_pred||).
+  prod_sum = mtf.reduce_sum(y_true * y_pred, reduced_dim=reduced_dim)
+  y_true_sq_sum = mtf.reduce_sum(y_true * y_true, reduced_dim=reduced_dim)
+  y_pred_sq_sum = mtf.reduce_sum(y_pred * y_pred, reduced_dim=reduced_dim)
+  inv_denom = mtf.rsqrt(y_true_sq_sum * y_pred_sq_sum + epsilon)
+  cos = prod_sum * inv_denom
+  # TODO(vinaysrao): Turn this into a more general cosine embedding loss with
+  # a `targets` tensor.
+  return mtf.reduce_sum(weights * (1. - cos))
+
+
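A distillation-style sketch for `cosine_embedding_distill` under the same assumed setup; `teacher_logits` and `student_logits` are placeholders for real model outputs:

  # Placeholders standing in for teacher and student model outputs.
  teacher_logits = mtf.random_uniform(mesh, mtf.Shape([batch, vocab]))
  student_logits = mtf.random_uniform(mesh, mtf.Shape([batch, vocab]))
  # Cosine distance is computed along `vocab`; remaining dims are summed.
  loss = cosine_embedding_distill(
      teacher_logits, student_logits, reduced_dim=vocab)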
 def sigmoid_cross_entropy_with_logits(logits, targets):
   """Sigmoid cross-entropy loss.
 