from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
-from vllm.spec_decode.util import (create_sequence_group_output,
+from vllm.spec_decode.util import (Timer, create_sequence_group_output,
                                   get_all_num_logprobs,
                                   get_sampled_token_logprobs, nvtx_range,
                                   split_batch_by_proposal_len)
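
Note: the `Timer` imported here is used later in this diff as a context manager whose `elapsed_time_ms` attribute is read after the block exits. As a rough sketch only (an assumption for illustration, not the actual `vllm.spec_decode.util` implementation), a timer with that interface could look like:

import time

class Timer:
    """Context manager that records how long its block took, in milliseconds."""

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Matches how the timers below are consumed,
        # e.g. proposal_timer.elapsed_time_ms.
        self.elapsed_time_ms = (time.perf_counter() - self._start) * 1000
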
@@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
        typical_acceptance_sampler_posterior_threshold,
        typical_acceptance_sampler_posterior_alpha=speculative_config.
        typical_acceptance_sampler_posterior_alpha,
-        disable_logprobs=speculative_config.disable_logprobs)
+        disable_logprobs=speculative_config.disable_logprobs,
+        disable_log_stats=speculative_config.disable_log_stats,
+    )

    return spec_decode_worker

@@ -116,6 +118,7 @@ def create_worker(
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
+        disable_log_stats: bool,
    ) -> "SpecDecodeWorker":

        allow_zero_draft_token_step = True
@@ -171,6 +174,7 @@ def create_worker(
            proposer_worker,
            scorer_worker,
            disable_logprobs=disable_logprobs,
+            disable_log_stats=disable_log_stats,
            disable_by_batch_size=disable_by_batch_size,
            spec_decode_sampler=spec_decode_sampler,
            allow_zero_draft_token_step=allow_zero_draft_token_step)
@@ -180,7 +184,8 @@ def __init__(
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        spec_decode_sampler: SpecDecodeBaseSampler,
-        disable_logprobs: bool,
+        disable_logprobs: bool = False,
+        disable_log_stats: bool = False,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
        allow_zero_draft_token_step: Optional[bool] = True,
@@ -203,6 +208,8 @@ def __init__(
            disable_logprobs: If set to True, token log probabilities will
                not be output in both the draft worker and the target worker.
                If set to False, log probabilities will be output by both.
+            disable_log_stats: If set to True, disable periodic printing of
+                speculative stage times.
            disable_by_batch_size: If the batch size is larger than this,
                disable speculative decoding for new incoming requests.
            metrics_collector: Helper class for collecting metrics; can be set
@@ -240,6 +247,7 @@ def __init__(
        # in the subsequent step.
        self.previous_hidden_states: Optional[HiddenStates] = None
        self._disable_logprobs = disable_logprobs
+        self._disable_log_stats = disable_log_stats

    def init_device(self) -> None:
        """Initialize both scorer and proposer models.
@@ -525,28 +533,37 @@ def _run_speculative_decoding_step(
        execute_model_req.previous_hidden_states = self.previous_hidden_states
        self.previous_hidden_states = None

-        # Generate proposals using draft worker.
-        proposals = self.proposer_worker.get_spec_proposals(
-            execute_model_req, self._seq_with_bonus_token_in_last_step)
+        with Timer() as proposal_timer:
+            # Generate proposals using draft worker.
+            proposals = self.proposer_worker.get_spec_proposals(
+                execute_model_req, self._seq_with_bonus_token_in_last_step)

        if not self._allow_zero_draft_token_step and proposals.no_proposals:
            #TODO: Fix it #5814
            raise RuntimeError("Cannot handle cases where distributed draft "
                               "workers generate no tokens")

-        proposal_scores = self.scorer.score_proposals(
-            execute_model_req,
-            proposals,
-        )
-        accepted_token_ids, target_logprobs = self._verify_tokens(
-            execute_model_req.seq_group_metadata_list, proposal_scores,
-            proposals, execute_model_req.num_lookahead_slots)
+        with Timer() as scoring_timer:
+            proposal_scores = self.scorer.score_proposals(
+                execute_model_req,
+                proposals,
+            )
+
+        with Timer() as verification_timer:
+            accepted_token_ids, target_logprobs = self._verify_tokens(
+                execute_model_req.seq_group_metadata_list, proposal_scores,
+                proposals, execute_model_req.num_lookahead_slots)
+
+        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
+                       scoring_timer.elapsed_time_ms,
+                       verification_timer.elapsed_time_ms)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
-            k=execute_model_req.num_lookahead_slots)
+            k=execute_model_req.num_lookahead_slots,
+            stage_times=stage_times)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
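
Note: the proposal time is divided by `num_lookahead_slots`, so the first entry of `stage_times` is a per-draft-token cost, while scoring and verification are reported per step. A quick illustration with made-up timings:

# Hypothetical timings for one speculative step with 4 lookahead slots.
num_lookahead_slots = 4
proposal_elapsed_ms = 8.0       # total time spent drafting 4 tokens
scoring_elapsed_ms = 12.5       # one batched target-model forward pass
verification_elapsed_ms = 0.9   # acceptance / rejection sampling

stage_times = (proposal_elapsed_ms / num_lookahead_slots,  # 2.0 ms per draft token
               scoring_elapsed_ms,
               verification_elapsed_ms)
assert stage_times == (2.0, 12.5, 0.9)
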
@@ -645,6 +662,7 @@ def _create_output_sampler_list(
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
+        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

@@ -722,8 +740,30 @@ def _create_output_sampler_list(
        if maybe_rejsample_metrics is not None:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics
+
+            # Log time spent in each stage periodically.
+            # This is periodic because the rejection sampler emits metrics
+            # periodically.
+            self._maybe_log_stage_times(*stage_times)
+
        return sampler_output_list

+    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
+                               scoring_time_ms: float,
+                               verification_time_ms: float) -> None:
+        """Log the speculative stage times. If stat logging is disabled, do
+        nothing.
+        """
+        if self._disable_log_stats:
+            return
+
+        logger.info(
+            "SpecDecodeWorker stage times: "
+            "average_time_per_proposal_tok_ms=%.02f "
+            "scoring_time_ms=%.02f verification_time_ms=%.02f",
+            average_time_per_proposal_tok_ms, scoring_time_ms,
+            verification_time_ms)
+
    def _create_dummy_logprob_lists(
        self,
        batch_size: int,
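
Note: when stat logging is enabled, the `logger.info` call above emits a line driven by that exact format string; with the made-up timings from the earlier illustration it would render roughly as follows (values hypothetical, and the line only appears on steps where the rejection-sampler metrics are also emitted):

# Renders: SpecDecodeWorker stage times: average_time_per_proposal_tok_ms=2.00
#          scoring_time_ms=12.50 verification_time_ms=0.90
print("SpecDecodeWorker stage times: "
      "average_time_per_proposal_tok_ms=%.02f "
      "scoring_time_ms=%.02f verification_time_ms=%.02f" %
      (2.0, 12.5, 0.9))
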