@@ -41,6 +41,7 @@
                          is_pin_memory_available, LazyLoader)
 from vllm_gaudi.utils import HPUCompileConfig, is_fake_hpu
 from vllm_gaudi.v1.attention.backends.hpu_attn import HPUAttentionMetadataV1
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -1469,7 +1470,6 @@ def _execute_model_generic(self,
                                           f"graphs{'T' if use_graphs else 'F'}")
         else:
             model_event_name = 'model_executable'
-        # with set_forward_context(attn_metadata, self.vllm_config):
         with self.profiler.record_event('internal', model_event_name):
             hidden_states = self.model.forward(
                 input_ids=token_ids,
@@ -1648,8 +1648,6 @@ def _pool(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        finished_sending: Optional[set[str]],
-        finished_recving: Optional[set[str]],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs == \
             len(self.input_batch.pooling_params), \
@@ -1691,7 +1689,7 @@ def _pool(
                 pooler_output.append(None)

         return ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
+            req_ids=[self.input_batch.req_ids],
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=[],
             logprobs=None,
@@ -1795,7 +1793,7 @@ def execute_model(
                 assert req_id is not None
                 seq_num_scheduled = scheduler_output.num_scheduled_tokens[req_id]
                 num_scheduled_tokens.append(seq_num_scheduled)
-                scheduled_req = scheduler_output.scheduled_new_reqs[int(req_id)]
+                scheduled_req = scheduler_output.scheduled_new_reqs[idx]
                 token_ids = scheduled_req.prompt_token_ids
                 # Convert to torch tensor if not already
                 if not isinstance(token_ids, torch.Tensor):
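The `+` line replaces ID-derived indexing with positional indexing: request IDs are opaque strings, while `scheduled_new_reqs` is ordered by batch position. A minimal sketch with hypothetical values:

# Hypothetical values for illustration only.
scheduled_new_reqs = ["new_req_state_0", "new_req_state_1"]  # scheduler order
req_ids = ["cmpl-42", "cmpl-7"]                              # opaque strings
for idx, req_id in enumerate(req_ids):
    scheduled_req = scheduled_new_reqs[idx]  # positional lookup is safe
    # int(req_id) would raise ValueError here ("cmpl-42" is not an int),
    # and even numeric IDs need not match batch positions.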
@@ -1807,21 +1805,33 @@ def execute_model(
             # Step 2: concatenate input_ids across requests
             input_ids = torch.cat(input_ids_list, dim=0).to(self.device)

-            # Step 3: build req_indices: [0,0,...,1,1,...,2,2,...] for each token
-            req_indices = np.repeat(np.arange(num_reqs, dtype=np.int64), num_scheduled_tokens)
+            # Step 3: add prefix lengths to get absolute positions
+            absolute_positions = []
+            for i, n in enumerate(num_scheduled_tokens):
+                prefix = num_computed_tokens_cpu[i]
+                absolute_positions.append(prefix + np.arange(n, dtype=np.int64))

-            # Step 4: arange: [0, 1, 2, ..., total_tokens-1]
-            total_tokens = sum(num_scheduled_tokens)
-            token_offset_within_request = np.arange(total_tokens, dtype=np.int64)
-
-            # Step 5: add prefix lengths to get absolute positions
-            absolute_positions_np = num_computed_tokens_cpu[req_indices] + token_offset_within_request
-
-            # Step 6: convert to torch tensor
+            absolute_positions_np = np.concatenate(absolute_positions)
+
+            # Step 4: convert to torch tensor
             position_ids = torch.from_numpy(absolute_positions_np).to(self.device)

             # (Optional) Attention mask
             attn_metadata = None
+            # attn_metadata: dict[str, Any] = {}
+
+            # per_layer_metadata = \
+            #     self._build_encoder_only_attn_metadata(
+            #         scheduler_output)
+
+            # # Add encoder attention metadata for all encoder layers
+            # attention_layers = get_layers_from_vllm_config(
+            #     self.vllm_config, Attention)
+            # for layer_name, attn_module in attention_layers.items():
+            #     if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+            #         common_attn_metadata, encoder_attn_metadata =\
+            #             per_layer_metadata[layer_name]
+            #         attn_metadata[layer_name] = encoder_attn_metadata

             # Forward pass
             with set_forward_context(
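A self-contained NumPy sketch of the Step 3 position computation added above, with assumed prefix lengths:

import numpy as np

# Two requests (toy values): 4 already-computed prefix tokens + 3 newly
# scheduled tokens, and 0 prefix + 2 scheduled tokens.
num_scheduled_tokens = [3, 2]
num_computed_tokens_cpu = np.array([4, 0], dtype=np.int64)

absolute_positions = []
for i, n in enumerate(num_scheduled_tokens):
    prefix = num_computed_tokens_cpu[i]
    absolute_positions.append(prefix + np.arange(n, dtype=np.int64))

print(np.concatenate(absolute_positions))  # [4 5 6 0 1]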
@@ -1833,15 +1843,17 @@ def execute_model(
             )

             # Pool the hidden states and return
-            num_scheduled_tokens_np = self.input_batch.num_tokens.copy()
+            if isinstance(num_scheduled_tokens, list):
+                num_scheduled_tokens_np = np.array(num_scheduled_tokens, dtype=np.int32)
+            else:
+                num_scheduled_tokens_np = num_scheduled_tokens
+            num_scheduled_tokens = int(num_scheduled_tokens_np.sum())

             # Pool the hidden states
             return self._pool(
                 hidden_states,
-                num_scheduled_tokens=int(token_ids.size(0)),
+                num_scheduled_tokens=num_scheduled_tokens,
                 num_scheduled_tokens_np=num_scheduled_tokens_np,
-                finished_sending=None,
-                finished_recving=None,
             )

         # If necessary, swap decodes/prompts to have all decodes on the start
         ensure_decodes_first(self.input_batch)
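A toy trace of the normalization above, showing why `_pool` now receives both a per-request array and a scalar total (values assumed):

import numpy as np

# List input, as built in the per-request loop above.
num_scheduled_tokens = [3, 2]
if isinstance(num_scheduled_tokens, list):
    num_scheduled_tokens_np = np.array(num_scheduled_tokens, dtype=np.int32)
else:
    num_scheduled_tokens_np = num_scheduled_tokens
num_scheduled_tokens = int(num_scheduled_tokens_np.sum())
print(num_scheduled_tokens)  # 5 -> scalar total passed to _pool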
@@ -1853,7 +1865,8 @@ def execute_model(
         with self.profiler.record_event('internal', 'prepare_input_tensors'):
             prefill_data, decode_data = self._prepare_inputs(
                 scheduler_output, num_prefills, num_decodes)
-        #FIXME(kzawora): Currently there's no handling of logprobs. Fix that
+
+        # #FIXME(kzawora): Currently there's no handling of logprobs. Fix that
         # later.
         prefill_sampled_token_ids = []
         prefill_sampled_requests = []
@@ -2912,3 +2925,66 @@ def reload_weights(self) -> None:
         logger.info("Reloading weights inplace...")
         model_loader.load_weights(self.model, model_config=self.model_config)
         torch.hpu.synchronize()
+
+    def _build_encoder_only_attn_metadata(
+            self, scheduler_output: "SchedulerOutput") -> \
+            dict[str, tuple[CommonAttentionMetadata, Any]]:
+        """Prepare encoder attention metadata for encoder-only models.
+
+        Args:
+            scheduler_output: Scheduler output
+
+        Returns:
+            dict[str, tuple[CommonAttentionMetadata, Any]]: Per-layer
+                (common metadata, backend-specific metadata) tuples.
+        """
+        num_reqs = self.input_batch.num_reqs
+        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+
+        # Get the number of scheduled tokens for each request.
+        req_ids = self.input_batch.req_ids
+        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+        max_num_scheduled_tokens = max(tokens)
+
+        dummy_block_table = torch.zeros((num_reqs, 1),
+                                        dtype=torch.int32,
+                                        device=self.device)
+        dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
+                                         dtype=torch.int32,
+                                         device=self.device)
+
+        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
+
+        for attn_group_list in self.attn_groups:
+
+            assert len(attn_group_list) == 1
+            attn_group = attn_group_list[0]
+
+            # Use the first attention metadata builder
+            # to create encoder attention metadata
+            builder = attn_group.metadata_builder
+
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+                seq_lens=self.seq_lens[:num_reqs],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                max_seq_len=self.seq_lens_cpu[:num_reqs].max().item(),
+                block_table_tensor=dummy_block_table,
+                slot_mapping=dummy_slot_mapping,
+                causal=False,
+            )
+
+            metadata = builder.build(
+                common_prefix_len=0,  # No cascade for encoder
+                common_attn_metadata=common_metadata,
+            )
+
+            for layer_name in attn_group.layer_names:
+                group_metadata[layer_name] = (common_metadata, metadata)
+
+        return group_metadata
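The commented-out block in `execute_model` above is the intended consumer of this helper. A sketch of that wiring as a standalone function (the wrapper name and `runner` parameter are hypothetical; the import paths follow the names referenced in the commented code):

from typing import Any

from vllm.attention import Attention, AttentionType
from vllm.config import get_layers_from_vllm_config


def apply_encoder_only_attn_metadata(runner, scheduler_output) -> dict[str, Any]:
    # `runner` stands in for the HPU model runner instance.
    per_layer_metadata = runner._build_encoder_only_attn_metadata(
        scheduler_output)

    # Add encoder attention metadata for all encoder-only layers.
    attn_metadata: dict[str, Any] = {}
    attention_layers = get_layers_from_vllm_config(runner.vllm_config,
                                                   Attention)
    for layer_name, attn_module in attention_layers.items():
        if attn_module.attn_type == AttentionType.ENCODER_ONLY:
            common_attn_metadata, encoder_attn_metadata = \
                per_layer_metadata[layer_name]
            attn_metadata[layer_name] = encoder_attn_metadata
    return attn_metadata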