Commit db3049a (parent: 893522a)

Fixed to process all requests

Signed-off-by: slokesha <[email protected]>

1 file changed: 96 additions, 20 deletions

vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -41,6 +41,7 @@
                         is_pin_memory_available, LazyLoader)
 from vllm_gaudi.utils import HPUCompileConfig, is_fake_hpu
 from vllm_gaudi.v1.attention.backends.hpu_attn import HPUAttentionMetadataV1
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -1469,7 +1470,6 @@ def _execute_model_generic(self,
                 f"graphs{'T' if use_graphs else 'F'}")
         else:
             model_event_name = 'model_executable'
-        # with set_forward_context(attn_metadata, self.vllm_config):
         with self.profiler.record_event('internal', model_event_name):
             hidden_states = self.model.forward(
                 input_ids=token_ids,
@@ -1648,8 +1648,6 @@ def _pool(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        finished_sending: Optional[set[str]],
-        finished_recving: Optional[set[str]],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1691,7 +1689,7 @@ def _pool(
             pooler_output.append(None)

         return ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
+            req_ids=[self.input_batch.req_ids],
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=[],
             logprobs=None,
@@ -1795,7 +1793,7 @@ def execute_model(
                 assert req_id is not None
                 seq_num_scheduled = scheduler_output.num_scheduled_tokens[req_id]
                 num_scheduled_tokens.append(seq_num_scheduled)
-                scheduled_req = scheduler_output.scheduled_new_reqs[int(req_id)]
+                scheduled_req = scheduler_output.scheduled_new_reqs[idx]
                 token_ids = scheduled_req.prompt_token_ids
                 # Convert to torch tensor if not already
                 if not isinstance(token_ids, torch.Tensor):
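Note on the indexing change above: scheduled_new_reqs is a plain list, so it is now indexed by the loop position idx rather than by int(req_id). Request ids are typically opaque strings that neither parse to nor match a position in that list, which appears to be what previously kept some requests from being processed (per the commit title). A toy sketch of the difference, with made-up data and names that are not the runner's:

    # Hypothetical data, for illustration only.
    scheduled_new_reqs = ["new-req-data-A", "new-req-data-B"]
    req_ids = ["chatcmpl-7", "chatcmpl-12"]  # opaque string ids

    for idx, req_id in enumerate(req_ids):
        # Positional indexing visits every scheduled request in order;
        # int(req_id) would raise (or pick the wrong entry) for ids like these.
        scheduled_req = scheduled_new_reqs[idx]
        print(idx, req_id, scheduled_req)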
@@ -1807,21 +1805,33 @@ def execute_model(
             # Step 2: concatenate input_ids across requests
             input_ids = torch.cat(input_ids_list, dim=0).to(self.device)

-            # Step 3: build req_indices: [0,0,...,1,1,...,2,2,...] for each token
-            req_indices = np.repeat(np.arange(num_reqs, dtype=np.int64), num_scheduled_tokens)
+            # Step 3: add prefix lengths to get absolute positions
+            absolute_positions = []
+            for i, n in enumerate(num_scheduled_tokens):
+                prefix = num_computed_tokens_cpu[i]
+                absolute_positions.append(prefix + np.arange(n, dtype=np.int64))

-            # Step 4: arange: [0, 1, 2, ..., total_tokens-1]
-            total_tokens = sum(num_scheduled_tokens)
-            token_offset_within_request = np.arange(total_tokens, dtype=np.int64)
-
-            # Step 5: add prefix lengths to get absolute positions
-            absolute_positions_np = num_computed_tokens_cpu[req_indices] + token_offset_within_request
-
-            # Step 6: convert to torch tensor
+            absolute_positions_np = np.concatenate(absolute_positions)
+
+            # Step 7: convert to torch tensor
             position_ids = torch.from_numpy(absolute_positions_np).to(self.device)

             # (Optional) Attention mask
             attn_metadata = None
+            # attn_metadata: dict[str, Any] = {}
+
+            # per_layer_metadata = \
+            #     self._build_encoder_only_attn_metadata(
+            #     scheduler_output)
+
+            # # Add encoder attention metadata for all encoder layers
+            # attention_layers = get_layers_from_vllm_config(
+            #     self.vllm_config, Attention)
+            # for layer_name, attn_module in attention_layers.items():
+            #     if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+            #         common_attn_metadata, encoder_attn_metadata =\
+            #             per_layer_metadata[layer_name]
+            #         attn_metadata[layer_name] = encoder_attn_metadata

             # Forward pass
             with set_forward_context(
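The new Step 3 above offsets each request's local token arange by that request's already-computed prefix length, then concatenates the per-request arrays into one flat position vector. A minimal standalone sketch of the same computation (NumPy only; the input values are made up for illustration):

    import numpy as np

    # Two requests: 3 and 2 tokens scheduled this step, with 5 and 0
    # tokens already computed (prefix lengths).
    num_scheduled_tokens = [3, 2]
    num_computed_tokens_cpu = np.array([5, 0], dtype=np.int64)

    absolute_positions = []
    for i, n in enumerate(num_scheduled_tokens):
        prefix = num_computed_tokens_cpu[i]
        absolute_positions.append(prefix + np.arange(n, dtype=np.int64))

    absolute_positions_np = np.concatenate(absolute_positions)
    print(absolute_positions_np)  # [5 6 7 0 1]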
@@ -1833,15 +1843,17 @@ def execute_model(
                 )

             # Pool the hidden states and return
-            num_scheduled_tokens_np = self.input_batch.num_tokens.copy()
+            if isinstance(num_scheduled_tokens, list):
+                num_scheduled_tokens_np = np.array(num_scheduled_tokens, dtype=np.int32)
+            else:
+                num_scheduled_tokens_np = num_scheduled_tokens
+            num_scheduled_tokens = int(num_scheduled_tokens_np.sum())

             # Pool the hidden states
             return self._pool(
                 hidden_states,
-                num_scheduled_tokens=int(token_ids.size(0)),
+                num_scheduled_tokens=num_scheduled_tokens,
                 num_scheduled_tokens_np=num_scheduled_tokens_np,
-                finished_sending=None,
-                finished_recving=None,
             )
         # If necessary, swap decodes/prompts to have all decodes on the start
         ensure_decodes_first(self.input_batch)
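The old call passed int(token_ids.size(0)), i.e. the length of whichever prompt the loop handled last; the replacement derives both the per-request array and the scalar total from the same per-request list, treating the scalar as the batch-wide token count. A small sketch of that normalization with illustrative numbers:

    import numpy as np

    num_scheduled_tokens = [3, 2]  # per-request counts gathered above

    if isinstance(num_scheduled_tokens, list):
        num_scheduled_tokens_np = np.array(num_scheduled_tokens, dtype=np.int32)
    else:
        num_scheduled_tokens_np = num_scheduled_tokens

    total_scheduled = int(num_scheduled_tokens_np.sum())
    print(num_scheduled_tokens_np, total_scheduled)  # [3 2] 5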
@@ -1853,7 +1865,8 @@ def execute_model(
         with self.profiler.record_event('internal', 'prepare_input_tensors'):
             prefill_data, decode_data = self._prepare_inputs(
                 scheduler_output, num_prefills, num_decodes)
-        #FIXME(kzawora): Currently there's no handling of logprobs. Fix that
+
+        # #FIXME(kzawora): Currently there's no handling of logprobs. Fix that
         # later.
         prefill_sampled_token_ids = []
         prefill_sampled_requests = []
@@ -2912,3 +2925,66 @@ def reload_weights(self) -> None:
         logger.info("Reloading weights inplace...")
         model_loader.load_weights(self.model, model_config=self.model_config)
         torch.hpu.synchronize()
+
+    def _build_encoder_only_attn_metadata(
+            self, scheduler_output: "SchedulerOutput") -> \
+                dict[str, tuple[CommonAttentionMetadata, Any]]:
+        """Prepare encoder attention metadata for encoder-only models.
+
+        Args:
+            scheduler_output: Scheduler output
+
+        Returns:
+            dict[str, Any]: Encoder attention metadata
+        """
+        num_reqs = self.input_batch.num_reqs
+        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+
+        # Get the number of scheduled tokens for each request.
+        req_ids = self.input_batch.req_ids
+        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+        max_num_scheduled_tokens = max(tokens)
+
+        dummy_block_table = torch.zeros((num_reqs, 1),
+                                        dtype=torch.int32,
+                                        device=self.device)
+        dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
+                                         dtype=torch.int32,
+                                         device=self.device)
+
+        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
+
+        for attn_group_list in self.attn_groups:
+
+            assert len(attn_group_list) == 1
+            attn_group = attn_group_list[0]
+
+            # Use the first attention metadata builder
+            # to create encoder attention metadata
+            builder = attn_group.metadata_builder
+
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+                seq_lens=self.seq_lens[:num_reqs],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                max_seq_len=self.seq_lens_cpu[:num_reqs].max().item(),
+                block_table_tensor=dummy_block_table,
+                slot_mapping=dummy_slot_mapping,
+                causal=False,
+            )
+
+            metadata = builder.build(
+                common_prefix_len=0,  # No cascade for encoder
+                common_attn_metadata=common_metadata,
+            )
+
+            for layer_name in attn_group.layer_names:
+                group_metadata[layer_name] = (common_metadata, metadata)
+
+        return group_metadata
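The commented-out block left in execute_model above hints at how this helper is meant to be consumed once enabled. A hedged sketch of that wiring, assuming get_layers_from_vllm_config, Attention, and AttentionType are available from vLLM as the commented code implies (they are not imported by this diff):

    # Inside execute_model, mirroring the commented-out block above
    # (hypothetical wiring, not part of this commit's active code path).
    attn_metadata: dict[str, Any] = {}

    per_layer_metadata = self._build_encoder_only_attn_metadata(scheduler_output)

    attention_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
    for layer_name, attn_module in attention_layers.items():
        if attn_module.attn_type == AttentionType.ENCODER_ONLY:
            # Keep the backend-specific metadata; the CommonAttentionMetadata
            # half of the tuple stays available for callers that need it.
            common_attn_metadata, encoder_attn_metadata = \
                per_layer_metadata[layer_name]
            attn_metadata[layer_name] = encoder_attn_metadata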
