
Commit 82a1b1a

[Speculative decoding] Add periodic log with time spent in proposal/scoring/verification (#6963)
1 parent c0d8f16 commit 82a1b1a

5 files changed: +125 −35 lines
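In short: this change wraps the three stages of each speculative decoding step (proposal, scoring, verification) in a new `Timer` context manager, threads the existing engine-wide `disable_log_stats` flag through `SpeculativeConfig` into `SpecDecodeWorker`, and logs the measured stage times whenever the rejection sampler emits its periodic metrics. Proposal time is reported per proposed token (normalized by `num_lookahead_slots`); scoring and verification are reported per step.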

tests/spec_decode/test_spec_decode_worker.py

Lines changed: 48 additions & 20 deletions
@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
     target_worker = mock_worker()
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+        draft_worker,
+        target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector)
     exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

@@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
     set_random_seed(1)
 
     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+        draft_worker,
+        target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector)
     worker.init_device()
 
     vocab_size = 32_000
@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
 
     set_random_seed(1)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              disable_logprobs=False,
+                              metrics_collector=metrics_collector)
     worker.init_device()
 
     proposal_token_ids = torch.randint(low=0,
@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
 
     set_random_seed(1)
     spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              disable_logprobs=False,
+                              metrics_collector=metrics_collector)
     worker.init_device()
 
     proposal_token_ids = torch.randint(low=0,
@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int,
     set_random_seed(1)
 
     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), False,
-        metrics_collector)
+        proposer_worker=draft_worker,
+        scorer_worker=target_worker,
+        spec_decode_sampler=mock_spec_decode_sampler(
+            acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector,
+    )
 
     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int,
     set_random_seed(1)
 
     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), False,
-        metrics_collector)
+        proposer_worker=draft_worker,
+        scorer_worker=target_worker,
+        spec_decode_sampler=mock_spec_decode_sampler(
+            acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector,
+    )
 
     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str):
     spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              False, metrics_collector)
+    worker = SpecDecodeWorker(
+        proposer_worker=draft_worker,
+        scorer_worker=target_worker,
+        spec_decode_sampler=spec_decode_sampler,
+        disable_logprobs=False,
+        metrics_collector=metrics_collector,
+    )
     worker.init_device()
 
     draft_worker.init_device.assert_called_once()
@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method):
     target_worker = mock_worker()
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
 
-    worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+    worker = SpecDecodeWorker(proposer_worker=draft_worker,
+                              scorer_worker=target_worker,
+                              spec_decode_sampler=mock_spec_decode_sampler(
+                                  acceptance_sampler_method),
+                              metrics_collector=metrics_collector)
 
     kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
     worker.initialize_cache(**kwargs)
@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens():
         seq_group_metadata_list=seq_group_metadata_list,
         accepted_token_ids=accepted_token_ids,
         target_logprobs=target_token_logprobs,
-        k=k)
+        k=k,
+        stage_times=(0, 0, 0))
     # Verify that _seq_with_bonus_token_in_last_step contains the following:
     # 1. Sequence IDs that were already present in
     # _seq_with_bonus_token_in_last_step but were not part of the current
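The edits above only adapt existing call sites to the new constructor and `stage_times` parameters. A hypothetical companion test (not part of this commit) could exercise the new logging path directly. The sketch below mirrors the mock style used in this file and assumes pytest's standard `caplog` fixture can observe vLLM's logger (log propagation may need to be enabled for that):

```python
import logging
from unittest.mock import MagicMock

from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker


def test_stage_times_logged(caplog):
    # Build a worker entirely from mocks, as the tests above do.
    worker = SpecDecodeWorker(
        proposer_worker=MagicMock(),
        scorer_worker=MagicMock(),
        spec_decode_sampler=MagicMock(),
        disable_logprobs=False,
        disable_log_stats=False,
        metrics_collector=MagicMock(spec=AsyncMetricsCollector),
    )
    with caplog.at_level(logging.INFO):
        worker._maybe_log_stage_times(1.0, 2.0, 3.0)
    assert "SpecDecodeWorker stage times" in caplog.text
```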

vllm/config.py

Lines changed: 7 additions & 1 deletion
@@ -907,6 +907,7 @@ def maybe_create_spec_config(
         speculative_max_model_len: Optional[int],
         enable_chunked_prefill: bool,
         use_v2_block_manager: bool,
+        disable_log_stats: bool,
         speculative_disable_by_batch_size: Optional[int],
         ngram_prompt_lookup_max: Optional[int],
         ngram_prompt_lookup_min: Optional[int],
@@ -1095,7 +1096,8 @@ def maybe_create_spec_config(
                 typical_acceptance_sampler_posterior_threshold,
             typical_acceptance_sampler_posterior_alpha=\
                 typical_acceptance_sampler_posterior_alpha,
-            disable_logprobs=disable_logprobs
+            disable_logprobs=disable_logprobs,
+            disable_log_stats=disable_log_stats,
         )
 
     @staticmethod
@@ -1189,6 +1191,7 @@ def __init__(
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
         disable_logprobs: bool,
+        disable_log_stats: bool,
     ):
         """Create a SpeculativeConfig object.
 
@@ -1221,6 +1224,8 @@ def __init__(
                 sampling, target sampling, and after accepted tokens are
                 determined. If set to False, log probabilities will be
                 returned.
+            disable_log_stats: Whether to disable periodic printing of stage
+                times in speculative decoding.
         """
         self.draft_model_config = draft_model_config
         self.draft_parallel_config = draft_parallel_config
@@ -1235,6 +1240,7 @@ def __init__(
         self.typical_acceptance_sampler_posterior_alpha = \
             typical_acceptance_sampler_posterior_alpha
         self.disable_logprobs = disable_logprobs
+        self.disable_log_stats = disable_log_stats
 
         self._verify_args()
vllm/engine/arg_utils.py

Lines changed: 1 addition & 0 deletions
@@ -792,6 +792,7 @@ def create_engine_config(self, ) -> EngineConfig:
             speculative_max_model_len=self.speculative_max_model_len,
             enable_chunked_prefill=self.enable_chunked_prefill,
             use_v2_block_manager=self.use_v2_block_manager,
+            disable_log_stats=self.disable_log_stats,
             ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
             ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
             draft_token_acceptance_method=\
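End to end, the flag simply rides along with the engine arguments. A sketch, assuming the usual speculative-decoding fields of this vLLM generation (the model names are placeholders and must be resolvable for the config to build):

```python
from vllm.engine.arg_utils import EngineArgs

# --disable-log-stats already existed as an engine-wide flag; this commit
# additionally forwards it into SpeculativeConfig.
args = EngineArgs(
    model="facebook/opt-6.7b",              # placeholder target model
    speculative_model="facebook/opt-125m",  # placeholder draft model
    num_speculative_tokens=4,
    use_v2_block_manager=True,
    disable_log_stats=True,                 # silences the new stage-time log
)
engine_config = args.create_engine_config()
assert engine_config.speculative_config.disable_log_stats
```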

vllm/spec_decode/spec_decode_worker.py

Lines changed: 54 additions & 14 deletions
@@ -27,7 +27,7 @@
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
 from vllm.spec_decode.target_model_runner import TargetModelRunner
-from vllm.spec_decode.util import (create_sequence_group_output,
+from vllm.spec_decode.util import (Timer, create_sequence_group_output,
                                    get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
                                    split_batch_by_proposal_len)
@@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         typical_acceptance_sampler_posterior_threshold,
         typical_acceptance_sampler_posterior_alpha=speculative_config.
         typical_acceptance_sampler_posterior_alpha,
-        disable_logprobs=speculative_config.disable_logprobs)
+        disable_logprobs=speculative_config.disable_logprobs,
+        disable_log_stats=speculative_config.disable_log_stats,
+    )
 
     return spec_decode_worker

@@ -116,6 +118,7 @@ def create_worker(
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
         disable_logprobs: bool,
+        disable_log_stats: bool,
     ) -> "SpecDecodeWorker":
 
         allow_zero_draft_token_step = True
@@ -171,6 +174,7 @@ def create_worker(
             proposer_worker,
             scorer_worker,
             disable_logprobs=disable_logprobs,
+            disable_log_stats=disable_log_stats,
             disable_by_batch_size=disable_by_batch_size,
             spec_decode_sampler=spec_decode_sampler,
             allow_zero_draft_token_step=allow_zero_draft_token_step)
@@ -180,7 +184,8 @@ def __init__(
         proposer_worker: ProposerWorkerBase,
         scorer_worker: WorkerBase,
         spec_decode_sampler: SpecDecodeBaseSampler,
-        disable_logprobs: bool,
+        disable_logprobs: bool = False,
+        disable_log_stats: bool = False,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
         disable_by_batch_size: Optional[int] = None,
         allow_zero_draft_token_step: Optional[bool] = True,
@@ -203,6 +208,8 @@ def __init__(
             disable_logprobs: If set to True, token log probabilities will
                 not be output in both the draft worker and the target worker.
                 If set to False, log probabilities will be output by both.
+            disable_log_stats: If set to True, disable periodic printing of
+                speculative stage times.
             disable_by_batch_size: If the batch size is larger than this,
                 disable speculative decoding for new incoming requests.
             metrics_collector: Helper class for collecting metrics; can be set
metrics_collector: Helper class for collecting metrics; can be set
@@ -240,6 +247,7 @@ def __init__(
240247
# in the subsequent step.
241248
self.previous_hidden_states: Optional[HiddenStates] = None
242249
self._disable_logprobs = disable_logprobs
250+
self._disable_log_stats = disable_log_stats
243251

244252
def init_device(self) -> None:
245253
"""Initialize both scorer and proposer models.
@@ -525,28 +533,37 @@ def _run_speculative_decoding_step(
         execute_model_req.previous_hidden_states = self.previous_hidden_states
         self.previous_hidden_states = None
 
-        # Generate proposals using draft worker.
-        proposals = self.proposer_worker.get_spec_proposals(
-            execute_model_req, self._seq_with_bonus_token_in_last_step)
+        with Timer() as proposal_timer:
+            # Generate proposals using draft worker.
+            proposals = self.proposer_worker.get_spec_proposals(
+                execute_model_req, self._seq_with_bonus_token_in_last_step)
 
         if not self._allow_zero_draft_token_step and proposals.no_proposals:
             #TODO: Fix it #5814
             raise RuntimeError("Cannot handle cases where distributed draft "
                                "workers generate no tokens")
 
-        proposal_scores = self.scorer.score_proposals(
-            execute_model_req,
-            proposals,
-        )
-        accepted_token_ids, target_logprobs = self._verify_tokens(
-            execute_model_req.seq_group_metadata_list, proposal_scores,
-            proposals, execute_model_req.num_lookahead_slots)
+        with Timer() as scoring_timer:
+            proposal_scores = self.scorer.score_proposals(
+                execute_model_req,
+                proposals,
+            )
+
+        with Timer() as verification_timer:
+            accepted_token_ids, target_logprobs = self._verify_tokens(
+                execute_model_req.seq_group_metadata_list, proposal_scores,
+                proposals, execute_model_req.num_lookahead_slots)
+
+        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
+                       scoring_timer.elapsed_time_ms,
+                       verification_timer.elapsed_time_ms)
 
         return self._create_output_sampler_list(
             execute_model_req.seq_group_metadata_list,
             accepted_token_ids,
             target_logprobs=target_logprobs,
-            k=execute_model_req.num_lookahead_slots)
+            k=execute_model_req.num_lookahead_slots,
+            stage_times=stage_times)
 
     @nvtx_range("spec_decode_worker._verify_tokens")
     def _verify_tokens(
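The measurement pattern in isolation, as a self-contained sketch: the `time.sleep` calls stand in for the real proposal/scoring/verification work, everything else follows the diff above. Only the proposal time is divided by `num_lookahead_slots`, since the draft model runs once per lookahead token while scoring and verification each run once per step:

```python
import time


class Timer:
    """Same contract as the Timer added in vllm/spec_decode/util.py."""

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.elapsed_time_ms = (time.time() - self.start_time) * 1000


def run_step(num_lookahead_slots: int = 4):
    with Timer() as proposal_timer:
        time.sleep(0.01)   # stand-in for proposer_worker.get_spec_proposals

    with Timer() as scoring_timer:
        time.sleep(0.02)   # stand-in for scorer.score_proposals

    with Timer() as verification_timer:
        time.sleep(0.005)  # stand-in for _verify_tokens

    return (proposal_timer.elapsed_time_ms / num_lookahead_slots,
            scoring_timer.elapsed_time_ms,
            verification_timer.elapsed_time_ms)


print(run_step())  # e.g. (2.6, 20.4, 5.2), all in milliseconds
```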
@@ -645,6 +662,7 @@ def _create_output_sampler_list(
         accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
         target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
         k: int,
+        stage_times: Tuple[float, float, float],
     ) -> List[SamplerOutput]:
         """Given the accepted token ids, create a list of SamplerOutput.
 
@@ -722,8 +740,30 @@ def _create_output_sampler_list(
         if maybe_rejsample_metrics is not None:
             sampler_output_list[
                 0].spec_decode_worker_metrics = maybe_rejsample_metrics
+
+            # Log time spent in each stage periodically.
+            # This is periodic because the rejection sampler emits metrics
+            # periodically.
+            self._maybe_log_stage_times(*stage_times)
+
         return sampler_output_list
 
+    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
+                               scoring_time_ms: float,
+                               verification_time_ms: float) -> None:
+        """Log the speculative stage times. If stat logging is disabled, do
+        nothing.
+        """
+        if self._disable_log_stats:
+            return
+
+        logger.info(
+            "SpecDecodeWorker stage times: "
+            "average_time_per_proposal_tok_ms=%.02f "
+            "scoring_time_ms=%.02f verification_time_ms=%.02f",
+            average_time_per_proposal_tok_ms, scoring_time_ms,
+            verification_time_ms)
+
     def _create_dummy_logprob_lists(
         self,
         batch_size: int,
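Given the format string above, the periodically emitted record looks like the following (values illustrative):

    SpecDecodeWorker stage times: average_time_per_proposal_tok_ms=5.42 scoring_time_ms=41.30 verification_time_ms=3.18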

vllm/spec_decode/util.py

Lines changed: 15 additions & 0 deletions
@@ -1,3 +1,4 @@
+import time
 from contextlib import contextmanager
 from typing import Dict, List, Optional, Tuple
 

@@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs):
         yield
     finally:
         torch.cuda.nvtx.range_pop()
+
+
+class Timer:
+    """Basic timer context manager for measuring CPU time.
+    """
+
+    def __enter__(self):
+        self.start_time = time.time()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.end_time = time.time()
+        self.elapsed_time_s = self.end_time - self.start_time
+        self.elapsed_time_ms = self.elapsed_time_s * 1000
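Usage of the new helper is the plain context-manager protocol; a quick sketch:

```python
from vllm.spec_decode.util import Timer

with Timer() as t:
    total = sum(range(1_000_000))  # arbitrary CPU-bound work

print(f"elapsed: {t.elapsed_time_s:.4f} s ({t.elapsed_time_ms:.2f} ms)")
```

Note that despite the docstring, `time.time()` measures wall-clock rather than CPU time, so for stages that launch GPU kernels the measured span only includes whatever synchronization the wrapped calls perform themselves.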
