Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 5951e39

Browse files
committed
Add timer
1 parent 4ecb951 commit 5951e39

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

dist_run.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -394,7 +394,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
394394
s = set(prompt_lengths)
395395
assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
396396

397-
# with CUDATrackTime() as timer:
398397
# Need these global ids due to the API definition of dist.send and recv
399398
first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
400399
last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -411,14 +410,18 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
411410
# TODO: we need to pass `input_pos` and `cache_lane` to each stage.
412411
lane = 0
413412
kwargs = {"input_pos": input_pos, "cache_lane": lane}
414-
with torch.no_grad():
413+
with torch.no_grad(), CUDATrackTime() as timer:
415414
if pp_rank == first_pp_rank:
416415
output = prefiller.step(padded_sequence, **kwargs)
417416
elif pp_rank == last_pp_rank:
418417
output = prefiller.step(**kwargs)
419418
else: # middle pp ranks
420419
prefiller.step(**kwargs)
421420

421+
logger.info(
422+
f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
423+
)
424+
422425
# Decode the output -- first generated token
423426
if pp_rank == last_pp_rank:
424427
decode_results = _batch_decode_next_tokens(
@@ -456,7 +459,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
456459
decorder = ScheduleGPipe(decode_stage, mbs)
457460

458461
# Decoding
459-
with torch.no_grad():
462+
with torch.no_grad(), CUDATrackTime() as timer:
460463
for step in range(num_tokens - 1):
461464
kwargs = {"input_pos": input_pos, "cache_lane": lane}
462465
# sendrecv between last and first ranks, only if:
@@ -501,6 +504,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
501504

502505
input_pos += 1
503506

507+
logger.info(
508+
f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
509+
)
510+
504511
# Display the decoding results
505512

506513
# output formatted response via last pp group and tp rank 0

0 commit comments

Comments
 (0)