File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -557,7 +557,7 @@ def setup_model_and_optimizer_distillation(model_provider_func):
557557
558558 if isinstance (student_model , deepspeed .PipelineEngine ):
559559 # hack to get batch_fn from pretrain_gpt.py
560- student_model .set_batch_fn (model .module ._megatron_batch_fn )
560+ student_model .set_batch_fn (student_model .module ._megatron_batch_fn )
561561
562562 assert student_model .grid .get_pipe_parallel_rank () == mpu .get_pipeline_model_parallel_rank ()
563563 assert student_model .grid .get_slice_parallel_rank () == mpu .get_tensor_model_parallel_rank ()
@@ -612,7 +612,7 @@ def setup_model_and_optimizer_distillation(model_provider_func):
612612
613613 if isinstance (teacher_model , deepspeed .PipelineEngine ):
614614 # hack to get batch_fn from pretrain_gpt.py
615- teacher_model .set_batch_fn (model .module ._megatron_batch_fn )
615+ teacher_model .set_batch_fn (teacher_model .module ._megatron_batch_fn )
616616
617617 assert teacher_model .grid .get_pipe_parallel_rank () == mpu .get_pipeline_model_parallel_rank ()
618618 assert teacher_model .grid .get_slice_parallel_rank () == mpu .get_tensor_model_parallel_rank ()
You can’t perform that action at this time.
0 commit comments