@@ -757,10 +757,19 @@ def _prepare_inputs(
         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
+        # Note: pad query_start_loc to be non-decreasing, as kernels
+        # like FlashAttention require that.
+        self.query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1])
+        self.query_start_loc.copy_(self.query_start_loc_cpu, non_blocking=True)
+        query_start_loc = self.query_start_loc[:num_reqs + 1]

         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
+        # Fill unused with 0 for full cuda graph mode.
+        self.seq_lens_np[num_reqs:].fill(0)
+        self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
+        seq_lens = self.seq_lens[:num_reqs]

         # Copy the tensors to the GPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
@@ -776,22 +785,6 @@ def _prepare_inputs(
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)

-        self.query_start_loc[:num_reqs + 1].copy_(
-            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
-        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
-                                       non_blocking=True)
-
-        # Fill unused with 0 for full cuda graph mode.
-        self.seq_lens[num_reqs:].fill_(0)
-        # Note: pad query_start_loc to be non-decreasing, as kernels
-        # like FlashAttention requires that
-        self.query_start_loc[num_reqs + 1:].fill_(
-            self.query_start_loc_cpu[num_reqs].item())
-
-        query_start_loc = self.query_start_loc[:num_reqs + 1]
-
-        spec_decode_common_attn_metadata = None
-
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
         if not use_spec_decode:
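Taken together, the two hunks above move the padding work off the device: the cumulative query offsets and sequence lengths are now finalized in the CPU-side numpy views and sent to the GPU with a single non-blocking copy of each full buffer, instead of a partial copy followed by extra `fill_` calls on the GPU tensors. Below is a minimal sketch of that pattern, assuming a pinned CPU tensor backed by a numpy view; `max_num_reqs`, the token counts, and the local names are illustrative, not the runner's exact fields.

```python
import numpy as np
import torch

# Illustrative sizes; the real runner sizes these buffers from its config.
max_num_reqs = 8
num_reqs = 3
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pinned CPU buffer with a numpy view, plus its device-side twin.
query_start_loc_cpu = torch.zeros(max_num_reqs + 1,
                                  dtype=torch.int32,
                                  pin_memory=torch.cuda.is_available())
query_start_loc_np = query_start_loc_cpu.numpy()
query_start_loc_gpu = torch.zeros(max_num_reqs + 1,
                                  dtype=torch.int32,
                                  device=device)

# Cumulative token counts for the active requests, e.g. 4, 2 and 7 tokens.
cu_num_tokens = np.cumsum(np.array([4, 2, 7], dtype=np.int32))

# Fill and pad entirely on the CPU side: the tail repeats the last prefix
# sum so the array stays non-decreasing across the padded slots.
query_start_loc_np[0] = 0
query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1])

# One async host-to-device copy of the whole buffer; no fill_ on the GPU.
query_start_loc_gpu.copy_(query_start_loc_cpu, non_blocking=True)
print(query_start_loc_gpu)  # values: 0, 4, 6, 13, 13, 13, 13, 13, 13
```

Because the whole buffer is copied every step, its contents are always well defined for full CUDA graph capture, and the device tensor never needs a separate fill pass.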
@@ -860,6 +853,13 @@ def _prepare_inputs(
                 per_layer_metadata[layer_name]
             attn_metadata[layer_name] = encoder_attn_metadata

+        # Used in the below loop.
+        query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
+        num_computed_tokens_cpu = (
+            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
+        spec_decode_common_attn_metadata = None
+
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -874,12 +874,11 @@ def _prepare_inputs(
             blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)

             common_attn_metadata = CommonAttentionMetadata(
-                query_start_loc=self.query_start_loc[:num_reqs + 1],
-                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-                seq_lens=self.seq_lens[:num_reqs],
-                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-                num_computed_tokens_cpu=self.input_batch.
-                num_computed_tokens_cpu_tensor[:num_reqs],
+                query_start_loc=query_start_loc,
+                query_start_loc_cpu=query_start_loc_cpu,
+                seq_lens=seq_lens,
+                seq_lens_cpu=seq_lens_cpu,
+                num_computed_tokens_cpu=num_computed_tokens_cpu,
                 num_reqs=num_reqs,
                 num_actual_tokens=total_num_scheduled_tokens,
                 max_query_len=max_num_scheduled_tokens,
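As a toy check (values are illustrative and follow the sketch earlier), the flat tail of the padded `query_start_loc` and the zeroed tail of `seq_lens` make the unused slots captured by a full CUDA graph look like zero-length requests, which is the non-decreasing layout kernels such as FlashAttention expect:

```python
import numpy as np

# Toy values: 3 real requests, buffers sized for 8.
query_start_loc = np.array([0, 4, 6, 13, 13, 13, 13, 13, 13], dtype=np.int32)
seq_lens = np.array([9, 5, 20, 0, 0, 0, 0, 0], dtype=np.int32)

# Per-request query lengths: real requests keep their counts, while the
# padded slots collapse to zero because the prefix-sum tail is flat.
query_lens = np.diff(query_start_loc)
assert query_lens.tolist() == [4, 2, 7, 0, 0, 0, 0, 0]
assert (seq_lens[3:] == 0).all()
```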