@@ -394,7 +394,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
394394 s = set (prompt_lengths )
395395 assert len (s ) == 1 , f"prompt_lengths should be the same, got { s } "
396396
397- # with CUDATrackTime() as timer:
398397 # Need these global ids due to the API definition of dist.send and recv
399398 first_pp_rank_global_id = dist .get_global_rank (pp_group , first_pp_rank )
400399 last_pp_rank_global_id = dist .get_global_rank (pp_group , last_pp_rank )
@@ -411,14 +410,18 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
411410 # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
412411 lane = 0
413412 kwargs = {"input_pos" : input_pos , "cache_lane" : lane }
414- with torch .no_grad ():
413+ with torch .no_grad (), CUDATrackTime () as timer :
415414 if pp_rank == first_pp_rank :
416415 output = prefiller .step (padded_sequence , ** kwargs )
417416 elif pp_rank == last_pp_rank :
418417 output = prefiller .step (** kwargs )
419418 else : # middle pp ranks
420419 prefiller .step (** kwargs )
421420
421+ logger .info (
422+ f"{ color .green } Prefilling time: { timer .get_time ()} { timer .unit } for rank { rank } { color .reset } "
423+ )
424+
422425 # Decode the output -- first generated token
423426 if pp_rank == last_pp_rank :
424427 decode_results = _batch_decode_next_tokens (
@@ -456,7 +459,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
456459 decorder = ScheduleGPipe (decode_stage , mbs )
457460
458461 # Decoding
459- with torch .no_grad ():
462+ with torch .no_grad (), CUDATrackTime () as timer :
460463 for step in range (num_tokens - 1 ):
461464 kwargs = {"input_pos" : input_pos , "cache_lane" : lane }
462465 # sendrecv between last and first ranks, only if:
@@ -501,6 +504,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
501504
502505 input_pos += 1
503506
507+ logger .info (
508+ f"{ color .green } Decoding time: { timer .get_time ()} { timer .unit } for rank { rank } { color .reset } "
509+ )
510+
504511 # Display the decoding results
505512
506513 # output formatted response via last pp group and tp rank 0
0 commit comments