
Commit 2c1a4c8

Author: Mesh TensorFlow Team

Allow distillation to start later in training, after the student model is pretrained.

PiperOrigin-RevId: 354539434

1 parent 759b788 · commit 2c1a4c8

1 file changed: 13 additions, 3 deletions

mesh_tensorflow/transformer/transformer.py (+13 −3)
@@ -1643,6 +1643,7 @@ def __init__(self,
                teacher,
                temperature=None,
                fraction_soft=None,
+               distill_start_step=0,
                teacher_checkpoint=None,
                initialize_student_weights=False):
     """Create a StudentTeacher.
@@ -1656,6 +1657,8 @@ def __init__(self,
         target cross entropy to the training loss. The rest of the loss will be
         the cross entropy with the one-hot actual label. Required only when
         training.
+      distill_start_step: an int, the training step after which the teacher
+        loss is incorporated into the overall loss.
       teacher_checkpoint: a string, the path to the teacher checkpoint that we
         wish to use. Required only when training.
       initialize_student_weights: a boolean, if true then initialize any
@@ -1666,6 +1669,7 @@ def __init__(self,
     self.teacher = teacher
     self.temperature = temperature
     self.fraction_soft = fraction_soft
+    self.distill_start_step = distill_start_step
     self.teacher_checkpoint = teacher_checkpoint
     self.initialize_student_weights = initialize_student_weights
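For intuition, the docstring above describes a weighted mix of the hard (one-hot) and soft (teacher) cross entropies. A minimal worked example of that mix, with made-up values for every quantity:

# Worked example of the distillation loss mix; all numbers are hypothetical.
fraction_soft = 0.5   # weight on the teacher (soft) cross entropy
temperature = 2.0     # softmax temperature used for distillation
hard_loss = 2.0       # stand-in for the one-hot cross entropy
soft_loss = 1.5       # stand-in for the teacher cross entropy

loss = (1.0 - fraction_soft) * hard_loss \
    + temperature**2 * fraction_soft * soft_loss
print(loss)  # 0.5 * 2.0 + 4.0 * 0.5 * 1.5 = 4.0

Before distill_start_step the effective fraction_soft is zero, so the loss reduces to the hard cross entropy alone; this is what lets the student pretrain on its own before distillation kicks in.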

@@ -1740,9 +1744,15 @@ def call_simple(self,
     weights = mtf.cast(mtf.greater(targets, 0), soft_loss.dtype)
     soft_loss = (mtf.reduce_sum(soft_loss * weights) /
                  self.student.loss_denominator(targets, num_microbatches))
-
-    loss = (1.0 - self.fraction_soft) * hard_loss \
-        + self.temperature**2 * self.fraction_soft * soft_loss
+    global_step = tf.train.get_or_create_global_step()
+    current_fraction_soft = tf.cast(
+        tf.cond(
+            tf.math.greater(global_step, self.distill_start_step),
+            lambda: self.fraction_soft, lambda: tf.constant(0.0)),
+        dtype=tf.bfloat16)
+
+    loss = (1.0 - current_fraction_soft) * hard_loss \
+        + self.temperature**2 * current_fraction_soft * soft_loss
 
     return student_logits, loss
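The step gating introduced above can be reproduced in isolation. A minimal sketch of the same pattern, assuming TF1-style graph mode via tensorflow.compat.v1; the step threshold, weights, and loss values are all made up:

# Minimal sketch of the step-gated soft-loss fraction, outside of
# Mesh TensorFlow. Until global_step exceeds distill_start_step, the
# soft fraction is 0.0 and only the hard loss contributes.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

distill_start_step = 1000  # hypothetical: student pretrains for 1000 steps
fraction_soft = 0.5        # hypothetical weight on the teacher loss
temperature = 2.0          # hypothetical distillation temperature

global_step = tf.train.get_or_create_global_step()
current_fraction_soft = tf.cond(
    tf.math.greater(global_step, distill_start_step),
    lambda: tf.constant(fraction_soft),
    lambda: tf.constant(0.0))

hard_loss = tf.constant(2.0)  # stand-ins for the real loss tensors
soft_loss = tf.constant(1.5)

loss = (1.0 - current_fraction_soft) * hard_loss \
    + temperature**2 * current_fraction_soft * soft_loss

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(loss))  # 2.0: global_step is 0, so only the hard loss

The actual diff additionally casts the gated fraction to bfloat16 to match the dtype used in the surrounding Mesh TensorFlow loss computation.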
