
Commit 4ada515

fix dummy run
Signed-off-by: Wuxun Zhang <[email protected]>
1 parent 06a1851

File tree

1 file changed: +5 −2 lines changed


vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 5 additions & 2 deletions
@@ -447,7 +447,9 @@ def forward(self, *args, **kwargs):
         kwargs.update(model_mm_kwargs)
 
         num_input_tokens = input_ids.size(0) * input_ids.size(1)
-        with set_forward_context(attn_meta, self.vllm_config, num_tokens=num_input_tokens):
+        with set_forward_context(attn_meta,
+                                 self.vllm_config,
+                                 num_tokens=num_input_tokens):
             hidden_states = self.model(*args, **kwargs)
         if self._rotary_prepare_cos_sin is not None:
             self._reset_rotary_cos_sin()
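This first hunk only rewraps the `set_forward_context` call across three lines; behavior is unchanged. For context, the token count it receives is the product of the two `input_ids` dimensions, i.e. batch size times the (presumably padded) sequence length. A minimal sketch with hypothetical shapes:

```python
import torch

# Hypothetical shapes: input_ids here is 2D (batch, seq_len), so the
# total number of input tokens is the product of both dimensions.
input_ids = torch.zeros((4, 128), dtype=torch.long)
num_input_tokens = input_ids.size(0) * input_ids.size(1)
assert num_input_tokens == 4 * 128  # 512 tokens passed to set_forward_context
```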
@@ -2744,13 +2746,14 @@ def profile_run(self) -> None:
         max_seq_len = math.ceil(
             (self.max_num_tokens // self.max_prefill_batch_size) /
             self.block_size) * self.block_size
+        max_seq_len = min(max_seq_len, self.max_model_len)
         self._execute_dummy_scenario(
             (self.max_prefill_batch_size, max_seq_len, 0), None)
 
     def _dummy_run(self, max_num_batched_tokens: int) -> None:
         assert max_num_batched_tokens == 1
         prompt_cfg = None
-        decode_cfg = 1, 1
+        decode_cfg = 1, 1, 1
         # add dummy decode run
         self._execute_dummy_scenario(prompt_cfg, decode_cfg)
         return
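The `profile_run` change caps the dummy sequence length at the model's context window. The existing arithmetic rounds the per-request share of the token budget up to a whole number of blocks, which can overshoot `max_model_len`; the new `min(...)` clamps it. A minimal sketch of that arithmetic, assuming illustrative values (the real ones come from the vLLM config):

```python
import math

# Hypothetical config values for illustration only.
max_num_tokens = 8192         # token budget for the profile batch
max_prefill_batch_size = 2    # dummy prefill batch size
block_size = 128              # KV-cache block size
max_model_len = 2048          # model context window

# Round the per-request share of the budget up to whole blocks.
max_seq_len = math.ceil(
    (max_num_tokens // max_prefill_batch_size) / block_size) * block_size
print(max_seq_len)  # 4096 -- exceeds max_model_len without the fix

# The fix: never profile a sequence longer than the model allows.
max_seq_len = min(max_seq_len, max_model_len)
print(max_seq_len)  # 2048
```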
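The `_dummy_run` change widens `decode_cfg` from a 2-tuple to a 3-tuple, matching the shape of the 3-tuple prompt config `(self.max_prefill_batch_size, max_seq_len, 0)` that `profile_run` passes to `_execute_dummy_scenario`. The helper's actual signature isn't shown in this diff, but a hypothetical sketch illustrates why the shorter tuple would break if the helper unpacks three values per config:

```python
# Hypothetical stand-in for _execute_dummy_scenario; assumes each config
# unpacks into three values, mirroring the 3-tuple prompt_cfg above.
def execute_dummy_scenario(prompt_cfg, decode_cfg):
    if decode_cfg is not None:
        batch_size, seq_len, num_blocks = decode_cfg
        print(f"decode: batch={batch_size}, seq={seq_len}, blocks={num_blocks}")

execute_dummy_scenario(None, (1, 1, 1))  # unpacks cleanly
# execute_dummy_scenario(None, (1, 1))   # ValueError: not enough values to unpack
```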
