@@ -1643,6 +1643,7 @@ def __init__(self,
                teacher,
                temperature=None,
                fraction_soft=None,
+               distill_start_step=0,
                teacher_checkpoint=None,
                initialize_student_weights=False):
     """Create a StudentTeacher.
@@ -1656,6 +1657,8 @@ def __init__(self,
         target cross entropy to the training loss. The rest of the loss will be
         the cross entropy with the one-hot actual label. Required only when
         training.
+      distill_start_step: an int, the training step after which the teacher
+        (soft) loss is incorporated into the overall loss.
       teacher_checkpoint: a string, the path to the teacher checkpoint that we
         wish to use. Required only when training.
       initialize_student_weights: a boolean, if true then initialize any
@@ -1666,6 +1669,7 @@ def __init__(self,
     self.teacher = teacher
     self.temperature = temperature
     self.fraction_soft = fraction_soft
+    self.distill_start_step = distill_start_step
     self.teacher_checkpoint = teacher_checkpoint
     self.initialize_student_weights = initialize_student_weights
 
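For context, a hypothetical construction call showing where the new distill_start_step argument fits. The student/teacher objects, the numeric values, and the checkpoint path below are illustrative only and are not taken from this commit:

    # Illustrative only: `student` and `teacher` are assumed to be
    # already-built models accepted by StudentTeacher.
    student_teacher = StudentTeacher(
        student=student,
        teacher=teacher,
        temperature=2.0,
        fraction_soft=0.5,
        distill_start_step=10000,  # soft (teacher) loss ignored until this step
        teacher_checkpoint="/path/to/teacher_checkpoint")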
@@ -1740,9 +1744,15 @@ def call_simple(self,
       weights = mtf.cast(mtf.greater(targets, 0), soft_loss.dtype)
       soft_loss = (mtf.reduce_sum(soft_loss * weights) /
                    self.student.loss_denominator(targets, num_microbatches))
-
-    loss = (1.0 - self.fraction_soft) * hard_loss \
-           + self.temperature**2 * self.fraction_soft * soft_loss
+      global_step = tf.train.get_or_create_global_step()
+      current_fraction_soft = tf.cast(
+          tf.cond(
+              tf.math.greater(global_step, self.distill_start_step),
+              lambda: self.fraction_soft, lambda: tf.constant(0.0)),
+          dtype=tf.bfloat16)
+
+    loss = (1.0 - current_fraction_soft) * hard_loss \
+           + self.temperature**2 * current_fraction_soft * soft_loss
 
     return student_logits, loss
 
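As a sanity check on the schedule this hunk implements, here is a minimal, framework-free sketch of the same gating logic; the function name and the numbers are illustrative, not part of the commit. Before distill_start_step the effective fraction_soft is zero, so only the hard (one-hot) cross entropy contributes:

    def blended_loss(hard_loss, soft_loss, global_step,
                     fraction_soft=0.5, temperature=2.0, distill_start_step=1000):
      # Same gating as the tf.cond above: zero out the soft-loss weight
      # until the global step passes distill_start_step.
      current_fraction_soft = fraction_soft if global_step > distill_start_step else 0.0
      return ((1.0 - current_fraction_soft) * hard_loss
              + temperature**2 * current_fraction_soft * soft_loss)

    # With hard_loss=2.0 and soft_loss=1.0:
    #   blended_loss(2.0, 1.0, global_step=500)   -> 2.0 (distillation not started)
    #   blended_loss(2.0, 1.0, global_step=2000)  -> 3.0 (0.5*2.0 + 2.0**2*0.5*1.0)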