
Commit 2a8ea19

pr_feedback, ruff formatting
1 parent 30f70b8 commit 2a8ea19

File tree: 1 file changed (+12, -58)


dist_run.py

Lines changed: 12 additions & 58 deletions
@@ -29,7 +29,6 @@
     get_module_size,
     get_num_params,
     GPUMemoryMonitor,
-    TrackTime,
 )
 from distributed.verification_utils import find_cpu_tensors
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
@@ -124,7 +123,7 @@ def _encode_strings(
     strings: List[str],
     tokenizer,
     bos: bool = True,
-    device: str = "cuda",
+    device: torch.device = "cuda:0",
     dtype=torch.int64,
 ) -> List[torch.Tensor]:
     """Encode a list of prompt strings into a list of tensor token ids."""
@@ -142,7 +141,7 @@ def _create_padded_prompts(
     tokenizer,
     seqlen: int,
     start_pos: int,
-    device: str,
+    device: torch.device,
     pad_token_id: Optional[int] = None,
 ) -> Tuple[torch.Tensor, List[int]]:
     """
@@ -284,7 +283,7 @@ def main():
 
     # Distribute model on TP mesh
     model.distribute(tp_mesh)
-    # logger.info(f"Model: {model}")
+    logger.info(f"Model: {model}")
 
     mbs = 1  # number of micro-batches
     mb_size = 1  # micro-batch size
@@ -302,7 +301,7 @@ def main():
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
-    with TrackTime() as timer:
+    with CUDATrackTime() as timer:
         _load_model_weights(model, hf_model_name, device=device, model_config=config)
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
@@ -316,7 +315,7 @@ def main():
         f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
     )
 
-    # Setup input position
+    # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen, device=device)
     model.setup_input_pos(input_pos)
     model.eval()
@@ -398,57 +397,15 @@ def main():
     src = dist.get_global_rank(pp_group, last_pp_group)
 
     # Decoding
-    num_tokens = 10
-    """
-    with torch.no_grad():
-        for step in range(num_tokens + 1):  # +1 to include the initial prefill step
-            if pp_rank == 0:
-                schedule.step(padded_sequence)
-                dist.recv(x_recv, src, group=pp_group)
-                logger.info(f"RECEIVED {x_recv=}")
-                assert x_recv != 128006, f"next_token is header id={x_recv}"
-                _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)
-                logger.info(
-                    f"Updated padded seq start: {prompt_lengths=}, {padded_sequence[:, prompt_lengths[0]-1]=}"
-                )
-
-            elif pp_rank == last_pp_group:
-                output = schedule.step()
-                decode_results = _batch_decode_next_tokens(
-                    output, prompt_lengths, tokenizer
-                )
-                logger.info(
-                    f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
-                )
-
-                next_token = torch.tensor([decode_results[0][0]], device=device)
-                res.append(decode_results[0][1])
+    num_tokens = 40
 
-                # increment prompt lengths for next token
-                for i in range(len(prompt_lengths)):
-                    prompt_lengths[i] += 1
-                    logger.info(
-                        f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
-                    )
-
-                if step < num_tokens - 1:
-                    dist.send(next_token, dst, pp_group)
-                    logger.info(
-                        f"SENDING back...from rank={pp_rank} to dst={dst}, data {next_token.shape=}, {next_token=}"
-                    )
-                    assert next_token != 128006, f"next_token is header id={next_token}"
-
-            else:  # middle pp ranks
-                schedule.step()
-    """
     with torch.no_grad():
         for step in range(num_tokens):
             # first
             if pp_rank == 0:
                 schedule.step(padded_sequence)
                 # only receive if not last step
                 if step < num_tokens - 1:
-
                     dist.recv(
                         x_recv,
                         src,
@@ -465,9 +422,10 @@ def main():
                 decode_results = _batch_decode_next_tokens(
                     output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
                 )
-                logger.info(
-                    f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
-                )
+                if tp_rank == 0:
+                    logger.info(
+                        f"\n\n{color.green} {'Prefill' if step == 0 else '* Decode *'} responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
+                    )
 
                 next_token = torch.tensor([decode_results[0][0]], device=device)
                 res.append(decode_results[0][1])
@@ -479,9 +437,6 @@ def main():
                         f"output review {prompt_lengths[i]=}, {padded_sequence[i, prompt_lengths[i]-1]=}"
                     )
 
-                # logger.info(f"SENDING back...from {rank=} to {dst=}")
-                # logger.info(f"SENDING data {next_token.shape=}, {next_token=}")
-
                 # only send if not last step
                 if step < (num_tokens - 1):
                     dist.send(
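
Taken together, the decode-loop hunks above implement a token-passing pattern between the first and last pipeline stages. The condensed sketch below is pieced together from the deleted commented-out version and the visible context lines, and it reuses names already defined in main() (schedule, padded_sequence, x_recv, dst, src, pp_group); the live code's exact arguments may differ, so treat it only as an outline of the communication shape:

with torch.no_grad():
    for step in range(num_tokens):
        if pp_rank == 0:
            # first stage drives the pipeline with the (padded) prompt tokens
            schedule.step(padded_sequence)
            if step < num_tokens - 1:  # no token comes back after the final step
                dist.recv(x_recv, src, group=pp_group)
                _update_padded_sequence(padded_sequence, x_recv, res, prompt_lengths)
        elif pp_rank == last_pp_group:
            # last stage produces logits, decodes the next token, and returns it to rank 0
            output = schedule.step()
            decode_results = _batch_decode_next_tokens(
                output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
            )
            next_token = torch.tensor([decode_results[0][0]], device=device)
            res.append(decode_results[0][1])
            if step < num_tokens - 1:  # only send if not last step
                dist.send(next_token, dst, group=pp_group)
        else:
            # middle stages just run their pipeline stage
            schedule.step()
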
@@ -496,11 +451,10 @@ def main():
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_group and tp_rank == 0:
-        logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}")
+        logger.info(f"\nPrompt:{color.green} {prompt[0]} {color.reset}")
         formatted_response = "".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response}{color.reset} $$$$$")
+        logger.info(f"$$$$$$ {color.blue}{formatted_response}\n{color.reset} $$$$$")
 
-        logger.info(f"$$$$$$ {color.red}{res=}{color.reset} $$$$$")
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
