@@ -826,7 +826,8 @@ def _prepare_inputs(
         # Prepare encoder attention metadata separately
         # (encoder layers are not in KV cache groups)
         if self.is_encoder_only_model:
-            common_attn_metadata, encoder_attn_metadata = \
+
+            per_layer_metadata = \
                 self._build_encoder_only_attn_metadata(
                     scheduler_output)
 
@@ -835,6 +836,8 @@ def _prepare_inputs(
                 self.vllm_config, Attention)
             for layer_name, attn_module in attention_layers.items():
                 if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                    common_attn_metadata, encoder_attn_metadata = \
+                        per_layer_metadata[layer_name]
                     attn_metadata[layer_name] = encoder_attn_metadata
 
         # Prepare the attention metadata for each KV cache group and make layers
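In rough terms, the consumer side of this change now looks up each encoder-only layer's own `(common_attn_metadata, encoder_attn_metadata)` pair instead of reusing a single tuple for every layer. A minimal sketch, with plain strings standing in for the real metadata objects:

```python
from typing import Any

# Hypothetical per-layer result in the shape returned by the reworked
# _build_encoder_only_attn_metadata: layer name -> (common, per-backend) pair.
per_layer_metadata: dict[str, tuple[Any, Any]] = {
    "layers.0.attn": ("common_full", "meta_full"),
    "layers.1.attn": ("common_swa", "meta_swa"),
}

attn_metadata: dict[str, Any] = {}
for layer_name in per_layer_metadata:
    common_attn_metadata, encoder_attn_metadata = per_layer_metadata[layer_name]
    attn_metadata[layer_name] = encoder_attn_metadata

print(attn_metadata)  # {'layers.0.attn': 'meta_full', 'layers.1.attn': 'meta_swa'}
```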
@@ -2683,30 +2686,41 @@ def create_attn_groups(
         # Check if model is encoder-only
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
-        attn_specs = list[AttentionSpec]()
-        for attn_module in attn_layers.values():
+        attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
+        for layer_name, attn_module in attn_layers.items():
 
             if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                assert attn_module.sliding_window is None, "Sliding "
-                "window attention is not supported for encoder-only models"
-
-                attn_specs.append(
-                    FullAttentionSpec(block_size=block_size,
-                                      num_kv_heads=attn_module.num_kv_heads,
-                                      head_size=attn_module.head_size,
-                                      dtype=self.kv_cache_dtype,
-                                      use_mla=use_mla))
+                if attn_module.sliding_window is None:
+                    attn_spec: AttentionSpec = FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        use_mla=use_mla)
+                else:
+                    attn_spec = SlidingWindowSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        sliding_window=attn_module.sliding_window,
+                        use_mla=use_mla)
+                attn_specs[attn_spec].append(layer_name)
+
             else:
                 raise ValueError("Expected only encoder-only layers")
 
         if len(attn_specs) > 0:
-            assert len(attn_specs) == len(attn_layers), \
-                "All or none of the layers are expected to be encoder-only"
+            total_layers = 0
+            for attn_spec, layer_names in attn_specs.items():
 
-            attn_backends = get_attn_backends_for_layers(attn_layers.keys())
+                attn_backends = get_attn_backends_for_layers(layer_names)
+                total_layers += len(layer_names)
 
-            self.attn_groups.append(
-                create_attn_groups(attn_backends, attn_specs[0]))
+                self.attn_groups.append(
+                    create_attn_groups(attn_backends, attn_spec))
+            assert total_layers == len(attn_layers), \
+                "All or none of the layers are expected to be encoder-only"
             self.is_encoder_only_model = True
 
     def calculate_reorder_batch_threshold(self) -> None:
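For context, the grouping above works because the spec objects can serve as dictionary keys: layers whose specs compare equal collect under the same key, so full-attention and sliding-window layers land in separate groups. A minimal sketch, assuming a hashable (e.g. frozen-dataclass) spec type standing in for `FullAttentionSpec`/`SlidingWindowSpec`:

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in for the attention spec classes; frozen=True makes
# instances hashable and equal-by-value, so identical specs share one bucket.
@dataclass(frozen=True)
class Spec:
    block_size: int
    num_kv_heads: int
    sliding_window: Optional[int] = None

layers = {
    "layers.0.attn": Spec(block_size=16, num_kv_heads=8),
    "layers.1.attn": Spec(block_size=16, num_kv_heads=8, sliding_window=128),
    "layers.2.attn": Spec(block_size=16, num_kv_heads=8),
}

attn_specs: dict[Spec, list[str]] = defaultdict(list)
for layer_name, spec in layers.items():
    attn_specs[spec].append(layer_name)

# Two groups: full attention (layers 0 and 2) and sliding window (layer 1).
for spec, layer_names in attn_specs.items():
    print(spec.sliding_window, layer_names)
```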
@@ -3071,7 +3085,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
 
     def _build_encoder_only_attn_metadata(
             self, scheduler_output: "SchedulerOutput") -> \
-                tuple[CommonAttentionMetadata, Any]:
+                dict[str, tuple[CommonAttentionMetadata, Any]]:
         """Prepare encoder attention metadata for encoder-only models.
 
         Args:
@@ -3088,33 +3102,45 @@ def _build_encoder_only_attn_metadata(
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         max_num_scheduled_tokens = max(tokens)
 
-        # Use the first attention metadata builder
-        # to create encoder attention metadata
-        builder = self.attn_groups[0][0].metadata_builder
-
         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
                                         device=self.device)
         dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
                                          dtype=torch.int32,
                                          device=self.device)
 
-        common_metadata = CommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:num_reqs + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens=self.seq_lens[:num_reqs],
-            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-            num_computed_tokens_cpu=self.input_batch.
-            num_computed_tokens_cpu_tensor[:num_reqs],
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            block_table_tensor=dummy_block_table,
-            slot_mapping=dummy_slot_mapping,
-            causal=False,
-        )
+        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
 
-        return common_metadata, builder.build(
-            common_prefix_len=0,  # No cascade for encoder
-            common_attn_metadata=common_metadata,
-        )
+        for attn_group_list in self.attn_groups:
+
+            assert len(attn_group_list) == 1
+            attn_group = attn_group_list[0]
+
+            # Use the first attention metadata builder
+            # to create encoder attention metadata
+            builder = attn_group.metadata_builder
+
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+                seq_lens=self.seq_lens[:num_reqs],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                block_table_tensor=dummy_block_table,
+                slot_mapping=dummy_slot_mapping,
+                causal=False,
+            )
+
+            metadata = builder.build(
+                common_prefix_len=0,  # No cascade for encoder
+                common_attn_metadata=common_metadata,
+            )
+
+            for layer_name in attn_group.layer_names:
+                group_metadata[layer_name] = (common_metadata, metadata)
+
+        return group_metadata
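A rough sketch of the fan-out this loop performs, with placeholder groups and builder callables rather than the real `AttentionGroup`/builder API: metadata is built once per attention group and then registered under every layer name belonging to that group.

```python
from typing import Any, Callable

# Hypothetical groups: (layer_names, build_fn) pairs; the real code iterates
# self.attn_groups and calls each group's metadata_builder.build(...).
groups: list[tuple[list[str], Callable[[], Any]]] = [
    (["layers.0.attn", "layers.2.attn"], lambda: {"kind": "full"}),
    (["layers.1.attn"], lambda: {"kind": "sliding_window"}),
]

group_metadata: dict[str, Any] = {}
for layer_names, build in groups:
    metadata = build()  # one build per group, not one per layer
    for layer_name in layer_names:
        group_metadata[layer_name] = metadata

# Layers in the same group share the very same metadata object.
assert group_metadata["layers.0.attn"] is group_metadata["layers.2.attn"]
```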