
Commit 7708646

[Distributed] fix pp=1 case; clean up (#1149)
1 parent a645f8e commit 7708646

1 file changed: dist_run.py (60 additions, 85 deletions)
@@ -135,8 +135,8 @@ def _load_model_weights(stage_module, distribution, device, model_config):
 def _encode_strings(
     strings: List[str],
     tokenizer,
-    bos: bool = True,
-    device: torch.device = "cuda:0",
+    bos: bool,
+    device: torch.device,
     dtype=torch.int64,
 ) -> List[torch.Tensor]:
     """Encode a list of prompt strings into a list of tensor token ids."""
@@ -216,13 +216,13 @@ def _batch_decode_next_tokens(
 
 def _update_padded_sequence(
     padded_sequence: torch.Tensor,
-    x_recv: torch.Tensor,
-    res,
+    new_token: torch.Tensor,
     prompt_lengths: List[int],
 ) -> None:
+    # TODO: this is a hacky way to update the padded sequence: when there is
+    # more than one prompt, the for loop and the assignment is incompatible.
     for i in range(len(prompt_lengths)):
-        prompt_lengths[i] += 1
-        padded_sequence[i, prompt_lengths[i] - 1] = x_recv
+        padded_sequence[i, prompt_lengths[i]] = new_token
 
 
 def _cleanup():
@@ -267,19 +267,15 @@ def main(args):
     device_mesh = _create_device_mesh(mesh_dimensions)
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
+    logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
+
     tp_rank = tp_mesh.get_local_rank()
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-
-    logger.info(f"review: {pp_group=}, {tp_group= }")
-
-    logger.info(f"Created device mesh: {device_mesh}\n {tp_mesh=}, {pp_mesh=}\n")
-    # TODO - this assumes 1D mesh, need to update for 2D+ mesh
-    pp_group_size = pp_mesh.size()
-    tp_group_size = tp_mesh.size()
-
-    logger.info(f"pp_group_size: {pp_group_size}, tp_group_size: {tp_group_size}")
+    pp_group_size = pp_group.size()
+    tp_group_size = tp_group.size()
+    logger.info(f"{pp_group_size=}, {tp_group_size=}")
 
     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -316,7 +312,7 @@ def main(args):
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
 
     with CUDATrackTime() as timer:
-        _load_model_weights(model, hf_model_name, device=device, model_config=config)
+        _load_model_weights(model, distribution, device=device, model_config=config)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
@@ -327,7 +323,7 @@ def main(args):
     stage_size_formatted = bytes_to_readable(stage_size)
     stage_num_params = get_num_params(model)
     logger.info(
-        f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
+        f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
     )
 
     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
@@ -342,15 +338,9 @@ def main(args):
         pp_degree,
         device,
         input_args=(example_args,),
-        group=pp_mesh.get_group(),
+        group=pp_group,
     )
 
-    # this check confirms that there are no cpu tensors in the model..we expect this to be true.
-    cpu_tensors = find_cpu_tensors(stage.submod)
-    # logger.info(f"Found {len(cpu_tensors)} cpu tensors: {cpu_tensors}")
-    if len(cpu_tensors) > 0:
-        raise ValueError("Found cpu tensors in stage")
-
     prompt = [
         "What is snow?",
     ]
@@ -374,7 +364,6 @@ def main(args):
     ]
     """
 
-
     start_pos = 0
 
     # encode the prompt
@@ -388,88 +377,74 @@ def main(args):
         input_ids, tokenizer, seqlen, start_pos, device
     )
     logger.info(f"{prompt_lengths=}")
-    logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0]+1]=}")
-    if len(prompt_lengths) > 1:
-        logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1]+1]=}")
 
+    # create schedule
     schedule = ScheduleGPipe(stage, mbs)
-    logger.info(f"Created schedule: {schedule}")
 
     # with CUDATrackTime() as timer:
-    first_pp_group = 0
-    last_pp_group = pp_group_size - 1
-
-    x_recv = torch.zeros(1, device=device, dtype=torch.int64)
-    logger.info(f"{x_recv.shape=}")
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
 
-    last_global_rank = world_size - 1
+    # New token generated each iteration
+    new_token = torch.zeros(1, device=device, dtype=torch.int64)
     res = []
-    dst = None
-    src = None
-
-    if pp_rank == last_pp_group:
-        dst = dist.get_global_rank(pp_group, 0)
-    elif pp_rank == 0:
-        src = dist.get_global_rank(pp_group, last_pp_group)
-
-    # Decoding
     num_tokens = 40
 
+    # Decoding
     with torch.no_grad():
         for step in range(num_tokens):
-            # first
-            if pp_rank == 0:
-                schedule.step(padded_sequence)
-                # only receive if not last step
-                if step < num_tokens - 1:
-                    dist.recv(
-                        x_recv,
-                        src,
-                        group=pp_group,
-                    )
-                    _update_padded_sequence(
-                        padded_sequence, x_recv, res, prompt_lengths
-                    )
-
-            # last
-            elif pp_rank == last_pp_group:
+            # Run data through pipeline
+            if pp_rank == first_pp_rank:
+                output = schedule.step(padded_sequence)
+            elif pp_rank == last_pp_rank:
                 output = schedule.step()
-                # need to decode the output
+            else:  # middle pp ranks
+                schedule.step()
+
+            # Decode the output
+            if pp_rank == last_pp_rank:
                 decode_results = _batch_decode_next_tokens(
                     output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
                 )
                 if tp_rank == 0:
                     logger.info(
-                        f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
+                        f"{color.green} {'Prefill' if step == 0 else '* Decode *'} "
+                        f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
                     )
-
-                next_token = torch.tensor([decode_results[0][0]], device=device)
+                # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
+                new_token = torch.tensor([decode_results[0][0]], device=device)
                 res.append(decode_results[0][1])
 
-                # increment prompt lengths for next token
-                for i in range(len(prompt_lengths)):
-                    prompt_lengths[i] += 1
-                # logger.info(
-                #     f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
-                # )
-
-                # only send if not last step
-                if step < (num_tokens - 1):
-                    dist.send(
-                        next_token,
-                        dst,
-                        pp_group,
-                    )
+            # sendrecv between last and first ranks, only if:
+            # first_pp_rank != last_pp_rank.
+            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
+                dist.send(
+                    new_token,
+                    dst=dist.get_global_rank(pp_group, first_pp_rank),
+                    group=pp_group,
+                )
+            elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
+                dist.recv(
+                    new_token,
+                    src=dist.get_global_rank(pp_group, last_pp_rank),
+                    group=pp_group,
+                )
 
-            # middle pp ranks
-            else:
-                schedule.step()
+            # Update input sequence with new token
+            if pp_rank == first_pp_rank:
+                _update_padded_sequence(
+                    padded_sequence, new_token, prompt_lengths
+                )
+
+            # increment prompt lengths for next token
+            for i in range(len(prompt_lengths)):
+                prompt_lengths[i] += 1
 
     # output formatted response via last pp group and tp rank 0
-    if pp_rank == last_pp_group and tp_rank == 0:
-        logger.info(f"\nPrompt:{color.green} {prompt[0]} {color.reset}")
-        formatted_response = "".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response}\n{color.reset} $$$$$")
+    if pp_rank == last_pp_rank and tp_rank == 0:
+        logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}")
+        formatted_response = " ".join(res)
+        logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$")
 
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
