@@ -528,19 +528,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
528528 start_token_index :end_token_index ] = new_token_ids
529529 self .input_batch .num_tokens_no_spec [
530530 req_index ] = end_token_index
531- # Add spec_token_ids to token_ids_cpu.
532- spec_token_ids = (
533- scheduler_output .scheduled_spec_decode_tokens .get (
534- req_id , ()))
535- if spec_token_ids :
536- start_index = end_token_index
537- end_token_index += len (spec_token_ids )
538- self .input_batch .token_ids_cpu [
539- req_index ,
540- start_index :end_token_index ] = spec_token_ids
541- # NOTE(woosuk): `num_tokens` here may include spec tokens.
542531 self .input_batch .num_tokens [req_index ] = end_token_index
543532
533+ # Add spec_token_ids to token_ids_cpu.
534+ spec_token_ids = (
535+ scheduler_output .scheduled_spec_decode_tokens .get (req_id , ()))
536+ if spec_token_ids :
537+ num_spec_tokens = len (spec_token_ids )
538+ start_index = self .input_batch .num_tokens_no_spec [req_index ]
539+ end_token_index = start_index + num_spec_tokens
540+ self .input_batch .token_ids_cpu [
541+ req_index , start_index :end_token_index ] = spec_token_ids
542+ # NOTE(woosuk): `num_tokens` here may include spec tokens.
543+ self .input_batch .num_tokens [req_index ] += num_spec_tokens
544+
544545 # Add the new or resumed requests to the persistent batch.
545546 # The smaller empty indices are filled first.
546547 for req_id in req_ids_to_add :
0 commit comments