@@ -1650,7 +1650,10 @@ def __init__(self,
                teacher,
                temperature=None,
                fraction_soft=None,
-               distill_start_step=0,
+               mse_coeff=0.,
+               kl_coeff=0.,
+               cosine_coeff=0.,
+               distill_start_steps=0,
                teacher_checkpoint=None,
                initialize_student_weights=False):
     """Create a StudentTeacher.
@@ -1664,7 +1667,10 @@ def __init__(self,
         target cross entropy to the training loss. The rest of the loss will be
         the cross entropy with the one-hot actual label. Required only when
         training.
-      distill_start_step: an int, training steps after which teacher loss is
+      mse_coeff: a float, the MSE distillation loss coefficient.
+      kl_coeff: a float, the KL-divergence distillation loss coefficient.
+      cosine_coeff: a float, the cosine-embedding distillation loss coefficient.
+      distill_start_steps: an int, the training step after which teacher loss is
         incorporated in the overall loss.
       teacher_checkpoint: a string, the path to the teacher checkpoint that we
         wish to use. Required only when training.
@@ -1676,9 +1682,15 @@ def __init__(self,
     self.teacher = teacher
     self.temperature = temperature
     self.fraction_soft = fraction_soft
-    self.distill_start_step = distill_start_step
+    self.distill_start_steps = distill_start_steps
     self.teacher_checkpoint = teacher_checkpoint
     self.initialize_student_weights = initialize_student_weights
+    self.kl_coeff = kl_coeff
+    self.cosine_coeff = cosine_coeff
+    self.mse_coeff = mse_coeff
+    if (fraction_soft + kl_coeff + cosine_coeff + mse_coeff) > 1.:
+      raise ValueError("Distillation coefficients must not add up to a value "
+                       "greater than 1.")

   def call_simple(self,
                   inputs,
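For reference, a hypothetical construction of a StudentTeacher with the new arguments could look like the sketch below; `student`, `teacher`, and the checkpoint path are placeholders assumed to be defined elsewhere (as in the surrounding file), and the coefficient values are purely illustrative.

student_teacher = StudentTeacher(
    student=student,            # assumed: an already-built student Unitransformer
    teacher=teacher,            # assumed: the (typically larger) teacher model
    temperature=2.0,            # illustrative softmax temperature
    fraction_soft=0.5,          # weight on the soft cross-entropy term
    kl_coeff=0.2,               # weight on the KL-divergence term
    mse_coeff=0.1,              # weight on the MSE term
    cosine_coeff=0.1,           # weight on the cosine-embedding term
    distill_start_steps=10000,  # teacher loss kicks in after this step
    teacher_checkpoint="...")   # path to the teacher checkpoint (elided)
# 0.5 + 0.2 + 0.1 + 0.1 = 0.9 <= 1, so __init__ does not raise; after
# distillation starts, the remaining 0.1 of the loss weight stays on the hard
# one-hot cross entropy.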
@@ -1751,15 +1763,40 @@ def call_simple(self,
     weights = mtf.cast(mtf.greater(targets, 0), soft_loss.dtype)
     soft_loss = (mtf.reduce_sum(soft_loss * weights) /
                  self.student.loss_denominator(targets, num_microbatches))
+    if self.kl_coeff > 0.:
+      student_pred = mtf.softmax(student_logits / self.temperature,
+                                 output_vocab_dim)
+      kl_loss = mtf.layers.kl_divergence(
+          mtf.stop_gradient(soft_targets), student_pred, output_vocab_dim,
+          weights=weights)
+    else:
+      kl_loss = 0.
+    if self.cosine_coeff > 0.:
+      cosine_loss = mtf.layers.cosine_embedding_distill(
+          mtf.stop_gradient(teacher_logits), student_logits, output_vocab_dim,
+          weights=weights)
+    else:
+      cosine_loss = 0.
+    if self.mse_coeff > 0.:
+      mse_loss = mtf.reduce_sum(weights * mtf.reduce_mean(
+          mtf.square(mtf.stop_gradient(teacher_logits) - student_logits),
+          reduced_dim=output_vocab_dim)) / mtf.reduce_sum(weights)
+    else:
+      mse_loss = 0.
     global_step = tf.train.get_or_create_global_step()
-    current_fraction_soft = tf.cast(
+    distill_loss_fraction = (self.fraction_soft + self.kl_coeff +
+                             self.mse_coeff + self.cosine_coeff)
+    current_distill_fraction = tf.cast(
         tf.cond(
-            tf.math.greater(global_step, self.distill_start_step),
-            lambda: self.fraction_soft, lambda: tf.constant(0.0)),
+            tf.math.greater(global_step, self.distill_start_steps),
+            lambda: distill_loss_fraction, lambda: tf.constant(0.0)),
         dtype=tf.bfloat16)

-    loss = (1.0 - current_fraction_soft) * hard_loss \
-        + self.temperature**2 * current_fraction_soft * soft_loss
+    loss = (1.0 - current_distill_fraction) * hard_loss \
+        + current_distill_fraction * (
+            self.temperature**2 * soft_loss * self.fraction_soft +
+            self.kl_coeff * kl_loss + self.mse_coeff * mse_loss +
+            self.cosine_coeff * cosine_loss)

     return student_logits, loss

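To make the blended loss in the final hunk concrete, here is a minimal, self-contained scalar sketch (plain Python with made-up loss values, not the mesh-tensorflow code itself) of how the hard loss and the distillation terms are combined, including the distill_start_steps gating.

def blended_loss(hard_loss, soft_loss, kl_loss, mse_loss, cosine_loss,
                 global_step, temperature=2.0, fraction_soft=0.5,
                 kl_coeff=0.2, mse_coeff=0.1, cosine_coeff=0.1,
                 distill_start_steps=10000):
  # Total share of the loss handed to distillation; forced to zero until the
  # distillation start step, mirroring the tf.cond in the diff above.
  distill_fraction = fraction_soft + kl_coeff + mse_coeff + cosine_coeff
  current = distill_fraction if global_step > distill_start_steps else 0.0
  return ((1.0 - current) * hard_loss +
          current * (temperature**2 * soft_loss * fraction_soft +
                     kl_coeff * kl_loss + mse_coeff * mse_loss +
                     cosine_coeff * cosine_loss))

# Before the start step only the hard loss contributes; afterwards the
# distillation terms are phased in with their coefficients.
print(blended_loss(2.0, 1.0, 0.5, 0.3, 0.2, global_step=500))    # -> 2.0
print(blended_loss(2.0, 1.0, 0.5, 0.3, 0.2, global_step=20000))  # -> 0.1*2.0 + 0.9*2.15 = 2.135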