Commit 3d82ca6

[PP] Fix PP meta init
Uses meta device for tensors/model used before pipeline splitting.

*Important:* Relies on pytorch/pytorch#136243 to make PipelineStage avoid materializing the model and the input/output buffers eagerly. Relies on existing .to(device) in train.py to finally materialize the model.

ghstack-source-id: c15282c
Pull Request resolved: #588
1 parent 4b3f2e4 commit 3d82ca6
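
For context, a minimal sketch (not torchtitan code; the module and sizes are made up) of the meta-device pattern the commit relies on: parameters created on the meta device carry only shape/dtype metadata and own no storage, and real memory is allocated later, e.g. by .to_empty()/.to(device) as train.py already does for the model.

import torch
import torch.nn as nn

# Build the model under the meta device: no real memory is allocated yet.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096))
assert all(p.is_meta for p in model.parameters())

# Later, after pipeline splitting and the other parallelisms are applied,
# materialize the parameters on the real device and re-initialize them
# (to_empty() allocates storage but does not initialize values).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to_empty(device=device)
for m in model.modules():
    if isinstance(m, nn.Linear):
        m.reset_parameters()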

1 file changed

torchtitan/parallelisms/pipeline_llama.py

Lines changed: 9 additions & 8 deletions
@@ -104,9 +104,11 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
             model.norm = None
             model.output = None
 
-        # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
-        # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
-        # layers of the model that map to this stage, not the whole model.
+        # Note: these tensors are only here as metadata hints, so the pipelining runtime knows what size buffers to allocate.
+        # These tensors should be on the meta device, and so should the model; it will be allocated on the real device after
+        # applying all other parallelisms.
+
+        # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can avoid specifying input/output shapes
         mp_dtype = _mixed_precision_dtype(job_config, parallel_dims)
         batch_size = job_config.training.batch_size
         local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
@@ -117,18 +119,17 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
             model_config.vocab_size,
         )
         if is_first:
-            (input,) = _llama_trace_input(job_config, model_config, device=device)
+            (input,) = _llama_trace_input(job_config, model_config, device="meta")
         else:
             # later layers (assume all start w/ a transformer layer)
-            input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
+            input = torch.rand(layers_io_shape, dtype=mp_dtype, device="meta")
 
         if is_last:
-            output = torch.rand(output_layer_shape, dtype=torch.float32, device=device)
+            output = torch.rand(output_layer_shape, dtype=torch.float32, device="meta")
         else:
             # earlier layers (assume all end in a transformer layer)
-            output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
+            output = torch.rand(layers_io_shape, dtype=mp_dtype, device="meta")
 
-        model.to_empty(device=device)
         stage = PipelineStage(
             model,
             stage_idx,
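
As a brief aside (not part of the diff), the reason meta tensors work as buffer-size hints: they carry shape/dtype metadata but own no storage, so the pipelining runtime can size its buffers from them and defer real allocation to the target device. A tiny sketch with arbitrary sizes:

import torch

# Shape/dtype hint with no memory allocated (sizes here are arbitrary).
hint = torch.rand((8, 2048, 4096), dtype=torch.bfloat16, device="meta")
print(hint.shape, hint.dtype, hint.is_meta)  # torch.Size([8, 2048, 4096]) torch.bfloat16 True

# Real allocation is deferred and sized from the hint's metadata later.
device = "cuda" if torch.cuda.is_available() else "cpu"
buf = torch.empty(hint.shape, dtype=hint.dtype, device=device)
print(buf.device)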
