@@ -527,6 +527,14 @@ def _propose(
             input_ids = draft_token_ids_list[-1].int()
             positions += 1
 
+            if not self.torchair_graph_enabled:
+                attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
+                    1:batch_size + 1].tolist()
+                attn_metadata_i.decode.cos = builder.cos_cache[
+                    positions].unsqueeze(1).unsqueeze(2)
+                attn_metadata_i.decode.sin = builder.sin_cache[
+                    positions].unsqueeze(1).unsqueeze(2)
+
             # NOTE(woosuk): We should handle the case where the draft model
             # generates tokens beyond the max model length. Since it is complex
             # to remove such requests from the batch, we keep them in the batch
@@ -560,6 +568,8 @@ def _propose(
 
             if attn_metadata_i.prefill is not None:
                 attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
+                attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
+                )
                 attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
                 attn_metadata_i.prefill.input_positions = self.positions[:
                                                                          num_input_tokens]
@@ -569,6 +579,8 @@ def _propose(
                     self.runner.model_config.max_model_len)
             if attn_metadata_i.decode is not None:
                 attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
+                attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
+                )
                 attn_metadata_i.decode.input_positions = self.positions[:
                                                                         num_input_tokens]
                 attn_metadata_i.decode.max_seq_lens += 1
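
Below is a minimal, self-contained sketch (not the vllm-ascend code itself) of the per-step decode metadata the first hunk fills in when the torchair graph path is disabled: deriving actual_seq_lengths_q from query_start_loc and gathering per-position rotary cos/sin from precomputed caches. The cache shapes, the one-query-token-per-request layout, and the broadcast dimensions added by unsqueeze are assumptions for illustration only.

# Hypothetical standalone sketch of the technique in the diff above.
import torch

batch_size = 4
rot_dim = 64
max_position = 1024

# Precomputed rotary tables, analogous to builder.cos_cache / builder.sin_cache
# in the diff (assumed shape [max_position, rot_dim]).
cos_cache = torch.randn(max_position, rot_dim)
sin_cache = torch.randn(max_position, rot_dim)

# Per-request positions after the "positions += 1" step of the speculative loop.
positions = torch.tensor([7, 12, 3, 25])

# Cumulative query offsets; query_start_loc[1:batch_size + 1] gives each
# request's end offset, which the diff reuses as actual_seq_lengths_q.
query_start_loc = torch.arange(batch_size + 1)
actual_seq_lengths_q = query_start_loc[1:batch_size + 1].tolist()

# Gather rotary embeddings for the current positions and add broadcast
# dimensions, yielding [batch, 1, 1, rot_dim].
cos = cos_cache[positions].unsqueeze(1).unsqueeze(2)
sin = sin_cache[positions].unsqueeze(1).unsqueeze(2)

print(actual_seq_lengths_q)    # [1, 2, 3, 4]
print(cos.shape, sin.shape)    # torch.Size([4, 1, 1, 64]) for each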