Skip to content

Commit f0d6462 — "fix" (parent: d416657)

File tree: 1 file changed, 2 additions and 2 deletions.

megatron/training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def setup_model_and_optimizer_distillation(model_provider_func):

     if isinstance(student_model, deepspeed.PipelineEngine):
         # hack to get batch_fn from pretrain_gpt.py
-        student_model.set_batch_fn(model.module._megatron_batch_fn)
+        student_model.set_batch_fn(student_model.module._megatron_batch_fn)

     assert student_model.grid.get_pipe_parallel_rank() == mpu.get_pipeline_model_parallel_rank()
     assert student_model.grid.get_slice_parallel_rank() == mpu.get_tensor_model_parallel_rank()
@@ -612,7 +612,7 @@ def setup_model_and_optimizer_distillation(model_provider_func):

     if isinstance(teacher_model, deepspeed.PipelineEngine):
         # hack to get batch_fn from pretrain_gpt.py
-        teacher_model.set_batch_fn(model.module._megatron_batch_fn)
+        teacher_model.set_batch_fn(teacher_model.module._megatron_batch_fn)

     assert teacher_model.grid.get_pipe_parallel_rank() == mpu.get_pipeline_model_parallel_rank()
     assert teacher_model.grid.get_slice_parallel_rank() == mpu.get_tensor_model_parallel_rank()

Comments: 0 commit comments.