@@ -553,11 +553,11 @@ def update_from_output(
553
553
spec_token_ids = model_runner_output .spec_token_ids
554
554
logprobs = model_runner_output .logprobs
555
555
prompt_logprobs_dict = model_runner_output .prompt_logprobs_dict
556
- spec_decoding_stats = SpecDecodingStats () if self .log_stats else None
557
556
num_scheduled_tokens = scheduler_output .num_scheduled_tokens
558
557
559
558
new_running : list [Request ] = []
560
559
outputs : list [EngineCoreOutput ] = []
560
+ spec_decoding_stats : Optional [SpecDecodingStats ] = None
561
561
562
562
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
563
563
# loop can be a performance bottleneck. We should do our best to avoid
@@ -585,11 +585,10 @@ def update_from_output(
585
585
num_tokens_rejected = (len (scheduled_spec_token_ids ) + 1 -
586
586
len (generated_token_ids ))
587
587
request .num_computed_tokens -= num_tokens_rejected
588
-
589
- if spec_decoding_stats is not None :
590
- spec_decoding_stats .observe (
591
- num_draft_tokens = len (scheduled_spec_token_ids ),
592
- num_accepted_tokens = len (generated_token_ids ) - 1 )
588
+ spec_decoding_stats = self .make_spec_decoding_stats (
589
+ spec_decoding_stats ,
590
+ num_draft_tokens = len (scheduled_spec_token_ids ),
591
+ num_accepted_tokens = len (generated_token_ids ) - 1 )
593
592
594
593
cached_encoder_input_ids = (
595
594
self .encoder_cache_manager .get_cached_input_ids (request ))
@@ -744,3 +743,17 @@ def make_stats(
744
743
prefix_cache_stats = self .kv_cache_manager .make_prefix_cache_stats (),
745
744
spec_decoding_stats = spec_decoding_stats ,
746
745
)
746
+
747
+ def make_spec_decoding_stats (
748
+ self ,
749
+ spec_decoding_stats : Optional [SpecDecodingStats ],
750
+ num_draft_tokens : int ,
751
+ num_accepted_tokens : int ,
752
+ ) -> Optional [SpecDecodingStats ]:
753
+ if not self .log_stats :
754
+ return None
755
+ if spec_decoding_stats is None :
756
+ spec_decoding_stats = SpecDecodingStats ()
757
+ spec_decoding_stats .observe (num_draft_tokens = num_draft_tokens ,
758
+ num_accepted_tokens = num_accepted_tokens )
759
+ return spec_decoding_stats
0 commit comments