@@ -527,6 +527,11 @@ def _propose(
527
527
input_ids = draft_token_ids_list [- 1 ].int ()
528
528
positions += 1
529
529
530
+ if not is_running_torchair :
531
+ attn_metadata_i .decode .actual_seq_lengths_q = attn_metadata_i .query_start_loc [1 :batch_size + 1 ].tolist ()
532
+ attn_metadata_i .decode .cos = builder .cos_cache [positions ].unsqueeze (1 ).unsqueeze (2 )
533
+ attn_metadata_i .decode .sin = builder .sin_cache [positions ].unsqueeze (1 ).unsqueeze (2 )
534
+
530
535
# NOTE(woosuk): We should handle the case where the draft model
531
536
# generates tokens beyond the max model length. Since it is complex
532
537
# to remove such requests from the batch, we keep them in the batch
@@ -560,6 +565,7 @@ def _propose(
560
565
561
566
if attn_metadata_i .prefill is not None :
562
567
attn_metadata_i .prefill .seq_lens = attn_metadata_i .seq_lens
568
+ attn_metadata_i .prefill .seq_lens_list = attn_metadata_i .prefill .seq_lens .tolist ()
563
569
attn_metadata_i .prefill .context_lens = attn_metadata_i .seq_lens
564
570
attn_metadata_i .prefill .input_positions = self .positions [:
565
571
num_input_tokens ]
@@ -569,6 +575,7 @@ def _propose(
569
575
self .runner .model_config .max_model_len )
570
576
if attn_metadata_i .decode is not None :
571
577
attn_metadata_i .decode .seq_lens = attn_metadata_i .seq_lens
578
+ attn_metadata_i .decode .seq_lens_list = attn_metadata_i .decode .seq_lens .tolist ()
572
579
attn_metadata_i .decode .input_positions = self .positions [:
573
580
num_input_tokens ]
574
581
attn_metadata_i .decode .max_seq_lens += 1
0 commit comments