This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 8d01d9b

[Distributed] Separate prefill and decode (#1162)
1 parent e27e162 commit 8d01d9b
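
A note on the shape of the change, as read from the diff below: instead of pushing the full padded sequence through a single pipeline schedule on every generation step, the script now builds two PipelineStage / ScheduleGPipe pairs over the same model, one sized for the prefill length (seqlen_prefill = 1024) and one for single-token decode. Prefill fills the KV caches once; after that, only the newly generated token and an advancing input_pos travel through the decode schedule. The following is a hypothetical single-process sketch of that control flow, with no pipeline or tensor parallelism; fake_model and all sizes are made up and are not part of dist_run.py.

import torch

vocab_size, seqlen_prefill, num_tokens = 100, 32, 8

def fake_model(tokens: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    # Stand-in for the pipelined transformer: returns random logits with the
    # same sequence length as its input, like the last pipeline stage would.
    return torch.rand(tokens.size(0), tokens.size(1), vocab_size)

# Prefill phase: run the whole padded prompt once, sized for seqlen_prefill.
prompt_len = 5
padded = torch.zeros(1, seqlen_prefill, dtype=torch.long)
padded[0, :prompt_len] = torch.randint(0, vocab_size, (prompt_len,))
input_pos = torch.arange(seqlen_prefill)
logits = fake_model(padded, input_pos)
new_token = torch.argmax(logits[:, prompt_len - 1, :], dim=-1, keepdim=True)

# Decode phase: every step sees a sequence length of 1; only the new token and
# an advancing input_pos are fed back in, instead of the whole padded sequence.
input_pos = torch.tensor([prompt_len])
generated = [new_token.item()]
for _ in range(num_tokens - 1):
    logits = fake_model(new_token, input_pos)
    new_token = torch.argmax(logits[:, 0, :], dim=-1, keepdim=True)
    generated.append(new_token.item())
    input_pos += 1

print(generated)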

1 file changed

dist_run.py

Lines changed: 112 additions & 78 deletions
@@ -187,8 +187,8 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    prompt_lengths: List[int],
     tokenizer,
+    prompt_lengths: Optional[List[int]] = None,
 ) -> List[Tuple[int, str]]:
     """
     Decode the next token for each prompt in the batch.
@@ -201,7 +201,8 @@ def _batch_decode_next_tokens(
     results = []
 
     for i in range(batch_size):
-        next_token_logits = output[i, prompt_lengths[i] - 1, :]
+        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
+        next_token_logits = output[i, pos, :]
 
         # Argmax (deterministic) TODO: add temperature
         next_token = torch.argmax(next_token_logits, dim=-1)
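
With prompt_lengths now optional, the same helper serves both phases: during prefill it reads the logits at each prompt's last real position, while during decode the stage output has sequence length 1, so position 0 already holds the next-token logits. A self-contained sketch of just that selection step (the function name and tensors below are made up; the real helper also maps token ids back to strings via the tokenizer argument):

from typing import List, Optional

import torch


def pick_next_tokens(
    output: torch.Tensor,                        # [batch, seqlen, vocab]
    prompt_lengths: Optional[List[int]] = None,  # None during the decode phase
) -> List[int]:
    tokens = []
    for i in range(output.size(0)):
        # Prefill: logits at the last real prompt position.
        # Decode: the output has seqlen 1, so position 0 holds the new logits.
        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
        tokens.append(torch.argmax(output[i, pos, :], dim=-1).item())
    return tokens


prefill_out = torch.rand(2, 8, 16)  # batch of 2, prompts padded to length 8
decode_out = torch.rand(2, 1, 16)   # one new position per decode step
print(pick_next_tokens(prefill_out, prompt_lengths=[5, 7]))
print(pick_next_tokens(decode_out))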
@@ -276,6 +277,10 @@ def main(args):
     tp_group_size = tp_group.size()
     logger.info(f"{pp_group_size=}, {tp_group_size=}")
 
+    # Convenience variables
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
+
     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
 
@@ -293,29 +298,23 @@ def main(args):
     logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
-    mb_size = 5  # micro-batch size
+    mb_size = 4  # micro-batch size
     batch_size = mbs * mb_size  # total batch size
 
-    seqlen = 4096  # sequence length
+    seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
     # Setup KV caches (after model distribution)
     # TODO: the setting below only works for 1 micro-batch case. To support
     # multiple micro-batches, we need the KV cache in the model to be aware of
     # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen)
-
-    mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
-    activation = torch.rand(
-        mb_size, seqlen, dim, device=device, dtype=model_dtype
-    )
-    example_args = mb_ids if pp_rank == 0 else activation
+    model.setup_caches(mb_size, seqlen_prefill)
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
-
     with CUDATrackTime() as timer:
         _load_model_weights(model, distribution, device=device, model_config=config)
+        model.to(device)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
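
The KV caches are still sized once, now for seqlen_prefill rather than 4096, and the decode phase later reuses them with single-token inputs, relying on input_pos to say where the next entry lands. A rough illustration of that mechanism with a hand-rolled cache buffer, assuming the usual [batch, heads, seq, head_dim] layout; the actual setup_caches / setup_input_pos internals are not part of this diff.

import torch

mb_size, n_heads, head_dim, seqlen_prefill = 4, 2, 8, 1024

# Cache allocated once for the full prefill length (per layer in a real model).
k_cache = torch.zeros(mb_size, n_heads, seqlen_prefill, head_dim)

# Prefill: input_pos covers every prompt position, filled in one pass.
prompt_len = 16
input_pos = torch.arange(prompt_len)
k_cache[:, :, input_pos] = torch.rand(mb_size, n_heads, prompt_len, head_dim)

# Decode: each step writes exactly one new slot, so the decode stage only ever
# sees tensors of sequence length 1 while input_pos advances by one per token.
input_pos = torch.tensor([prompt_len])
k_cache[:, :, input_pos] = torch.rand(mb_size, n_heads, 1, head_dim)
print(k_cache[:, :, : prompt_len + 1].abs().sum() > 0)  # both phases wrote into the cache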
@@ -330,53 +329,47 @@ def main(args):
     )
 
     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
-    input_pos = torch.arange(seqlen, device=device)
+    input_pos = torch.arange(seqlen_prefill, device=device)
     model.setup_input_pos(input_pos)
     model.eval()
 
-    logger.info(f"Creating pipeline stage {pp_rank=}, {pp_degree=}")
-    stage = PipelineStage(
+    # Helper function to get example inputs and outputs for the stages.
+    def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        activation = torch.rand(
+            mb_size, seqlen, dim, device=device, dtype=model_dtype
+        )
+        logits = torch.rand(
+            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+        )
+        example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
+        example_outputs = (logits if pp_rank == last_pp_rank else activation,)
+        return example_inputs, example_outputs
+
+    # Create prefill stage
+    logger.info(f"Creating pipeline stage for prefill {pp_rank=}, {pp_degree=}")
+    example_inputs, example_outputs = get_example_ins_outs(seqlen_prefill)
+    prefill_stage = PipelineStage(
         model,
         pp_rank,
         pp_degree,
         device,
-        input_args=(example_args,),
+        input_args=example_inputs,
+        output_args=example_outputs,
         group=pp_group,
     )
+    # create schedule
+    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
 
     prompt = [
-        "What is snow?",
-        "Where does Santa Claus live?",
-        "What is PyTorch?",
-        "Write a poem about the beauty of the night sky.",
-        "What is the capital of France, Germany and Switzerland?",
-    ]
-
-    """
-    "What is the capital of France?",
-    "What is your name?",
-    "What is the capital of Japan?",
-    "When is Christmas?",
-    "Where does Santa Claus live?",
-    "What is the capital of the United States?",
-    "What is the capital of China?",
-    "What is the capital of Russia?",
-    "What is PyTorch?",
-    "What is the capital of India?",
-    "What is an LLM?",
-    "What is the capital of Brazil?",
-    "What is the capital of Mexico?",
-    "What is the capital of Argentina?",
-    "What is the capital of Canada?",
+        "What is a computer?",
+        "Where does Santa live?",
+        "Who is Abraham Lincoln?",
+        "How are models trained?",
     ]
-    """
 
     start_pos = 0
 
-    # pipeline comms setup
-    first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
-
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
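
get_example_ins_outs exists because PipelineStage needs example input and output tensors of fixed shape, and those shapes differ between the prefill stage (sequence length seqlen_prefill) and the decode stage (sequence length 1): the first rank consumes token ids, middle ranks exchange activations, and the last rank emits logits. A standalone sketch of the resulting shapes, with made-up sizes and a hypothetical 4-rank layout:

# Shapes only; mirrors the helper above but is not code from dist_run.py.
mb_size, dim, vocab_size = 4, 4096, 32000
seqlen_prefill, seqlen_decode = 1024, 1
first_pp_rank, last_pp_rank = 0, 3


def example_shapes(pp_rank: int, seqlen: int):
    token_ids = (mb_size, seqlen)           # consumed by the first stage
    activation = (mb_size, seqlen, dim)     # passed between middle stages
    logits = (mb_size, seqlen, vocab_size)  # emitted by the last stage
    example_in = token_ids if pp_rank == first_pp_rank else activation
    example_out = logits if pp_rank == last_pp_rank else activation
    return example_in, example_out


for phase, seqlen in (("prefill", seqlen_prefill), ("decode", seqlen_decode)):
    for rank in range(last_pp_rank + 1):
        ins, outs = example_shapes(rank, seqlen)
        print(f"{phase:7s} pp_rank={rank}: input={ins} output={outs}")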
@@ -388,15 +381,14 @@ def main(args):
 
     # create a padded tensor for the input prompt
     padded_sequence, prompt_lengths = _create_padded_prompts(
-        input_ids, tokenizer, seqlen, start_pos, device
+        input_ids, tokenizer, seqlen_prefill, start_pos, device
     )
-
-    # create schedule
-    schedule = ScheduleGPipe(stage, mbs)
+    # TODO: figure out how to set input_pos for each prompt in the batch then we
+    # can remove this limitation.
+    s = set(prompt_lengths)
+    assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
     # with CUDATrackTime() as timer:
-    first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -408,25 +400,87 @@ def main(args):
     res = [[] for _ in range(total_prompts)]
     num_tokens = 40
 
+    # Prefill phase
+    # Run context input through pipeline, in 1 step
+    with torch.no_grad():
+        if pp_rank == first_pp_rank:
+            output = prefill_schedule.step(padded_sequence)
+        elif pp_rank == last_pp_rank:
+            output = prefill_schedule.step()
+        else:  # middle pp ranks
+            prefill_schedule.step()
+
+    # Decode the output -- first generated token
+    if pp_rank == last_pp_rank:
+        decode_results = _batch_decode_next_tokens(
+            output=output,
+            tokenizer=tokenizer,
+            prompt_lengths=prompt_lengths,
+        )
+        for i in range(len(decode_results)):
+            new_token[i, 0] = torch.tensor(
+                [decode_results[i][0]], device=device
+            )  # token_id in int form
+        if tp_rank == 0:
+            logger.info(
+                f"{color.green} {'* Prefill *'} "
+                f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
+            )
+
+    # seqlen = 1 now
+    seqlen_decode = 1
+    input_pos = torch.tensor([prompt_lengths[0]], device=device)
+    model.setup_input_pos(input_pos)
+
+    # Create decode stage
+    logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
+    example_inputs, example_outputs = get_example_ins_outs(seqlen_decode)
+    decode_stage = PipelineStage(
+        model,
+        pp_rank,
+        pp_degree,
+        device,
+        input_args=example_inputs,
+        output_args=example_outputs,
+        group=pp_group,
+    )
+    # create schedule
+    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+
     # Decoding
     with torch.no_grad():
-        for step in range(num_tokens):
+        for step in range(num_tokens - 1):
+            # sendrecv between last and first ranks, only if:
+            # first_pp_rank != last_pp_rank.
+            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=first_pp_rank_global_id,
+                    group=pp_group,
+                )
+            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=last_pp_rank_global_id,
+                    group=pp_group,
+                )
+
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = schedule.step(padded_sequence)
+                output = decode_schedule.step(new_token)
             elif pp_rank == last_pp_rank:
-                output = schedule.step()
+                output = decode_schedule.step()
             else:  # middle pp ranks
-                schedule.step()
+                decode_schedule.step()
 
             # Decode the output
             if pp_rank == last_pp_rank:
                 decode_results = _batch_decode_next_tokens(
-                    output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
+                    output=output, tokenizer=tokenizer
                 )
                 if tp_rank == 0:
                     logger.info(
-                        f"{color.green} {'Prefill' if step == 0 else '* Decode *'} "
+                        f"{color.green} {'* Decode *'} "
                         f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
                     )
                 # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
@@ -436,28 +490,8 @@ def main(args):
                         [decode_results[i][0]], device=device
                     )  # decode_results[i][0]
 
-            # sendrecv between last and first ranks, only if:
-            # first_pp_rank != last_pp_rank.
-            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
-                dist.send(
-                    new_token,
-                    dst=first_pp_rank_global_id,
-                    group=pp_group,
-                )
-            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
-                dist.recv(
-                    new_token,
-                    src=last_pp_rank_global_id,
-                    group=pp_group,
-                )
-
-            # Update input sequence with new token
-            if pp_rank == first_pp_rank:
-                _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
-
-            # increment prompt lengths for next token
-            for i in range(len(prompt_lengths)):
-                prompt_lengths[i] += 1
+            input_pos += 1
+            model.setup_input_pos(input_pos)
 
     # Display the decoding results
 
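Since the whole padded sequence is no longer re-fed each step, every decode iteration now begins by shipping the freshly sampled token from the last pipeline rank back to the first one via dist.send / dist.recv on new_token, which then becomes that step's one-token input while input_pos advances by one. Below is a hypothetical, minimal demo of just that hand-off pattern, runnable on CPU with the gloo backend; the address, port, sizes, and vocab are made up, and the real script additionally runs the decode schedule after the exchange.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    first_pp_rank, last_pp_rank = 0, world_size - 1
    mb_size = 4
    new_token = torch.zeros(mb_size, 1, dtype=torch.int64)

    for step in range(3):
        if rank == last_pp_rank and rank != first_pp_rank:
            # Stand-in for the argmax over the last stage's logits.
            new_token = torch.randint(0, 32000, (mb_size, 1))
            dist.send(new_token, dst=first_pp_rank)
        elif rank == first_pp_rank and rank != last_pp_rank:
            dist.recv(new_token, src=last_pp_rank)
            print(f"step {step}: first rank got tokens {new_token.flatten().tolist()}")
        # ...the real decode loop would now feed new_token through the pipeline...

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)
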