Commit ea44413

fix ci error

Signed-off-by: Wuxun Zhang <[email protected]>

1 parent 213f54b commit ea44413

File tree

1 file changed: +22 -26 lines changed

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 22 additions & 26 deletions
@@ -1633,13 +1633,12 @@ def _prepare_prefill_inputs(
         return all_batches[0], dummy_prefill_input_batches[
             0] if dummy_prefill_input_batches else None
 
-    def _create_decode_input_data(
-            self,
-            num_decodes,
-            num_scheduled_tokens,
-            context_lens,
-            block_table_cpu_tensor,
-            scheduler_output=None) -> tuple[DecodeInputData, int]:
+    def _create_decode_input_data(self,
+                                  num_decodes,
+                                  num_scheduled_tokens,
+                                  context_lens,
+                                  block_table_cpu_tensor,
+                                  scheduler_output=None) -> DecodeInputData:
         # NOTE(kzawora): the +1 is what causes this entire thing to work,
         # as in the paged attention, we don't fetch just the context from cache,
         # but also kvs for the current token
@@ -1652,8 +1651,7 @@ def _create_decode_input_data(
             num_decodes, sum(num_blocks))[0]
 
         # dp aware padding
-        num_pad_across_dp = self.get_dp_padding(padded_batch_size)
-        padded_batch_size += num_pad_across_dp
+        padded_batch_size += self.get_dp_padding(padded_batch_size)
 
         num_tokens_per_req = num_scheduled_tokens[:num_decodes]
         num_tokens = max(num_tokens_per_req)
@@ -1843,7 +1841,7 @@ def _create_decode_input_data(
                 block_size=self.block_size,
                 query_start_loc=query_start_loc,
             ),
-            spec_decode_metadata=spec_decode_metadata), num_pad_across_dp
+            spec_decode_metadata=spec_decode_metadata)
 
     def _prepare_decode_inputs(
             self,
@@ -1868,7 +1866,8 @@ def _prepare_decode_inputs(
         return self._create_decode_input_data(
             num_decodes, num_scheduled_tokens,
             self.input_batch.num_computed_tokens_cpu[:num_decodes],
-            self.input_batch.block_table[0].get_cpu_tensor(), scheduler_output)
+            self.input_batch.block_table[0].get_cpu_tensor(),
+            scheduler_output), None
 
     def _create_dummy_decode_input_data(self) -> DecodeInputData:
         # create dummy decode input data with batch size 1
@@ -1877,13 +1876,10 @@ def _create_dummy_decode_input_data(self) -> DecodeInputData:
         context_lens = [128]
         block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID],
                                              dtype=torch.int32).reshape(1, -1)
-        # num_computed_tokens_cpu = np.array([128], dtype=np.int32)
-        # token_ids = np.array(list(int(i) for i in range(context_lens[0])))
-
         return self._create_decode_input_data(num_dummy_decodes,
                                               num_dummy_scheduled_tokens,
                                               context_lens,
-                                              block_table_cpu_tensor)[0]
+                                              block_table_cpu_tensor)
 
     def _get_cumsum_and_arange(
             self,
@@ -2570,17 +2566,6 @@ def execute_model(
                     prompt_batch_idx=None,
                     is_prompt=False)
                 self.profiler.record_counter(self.event_start, counters)
-        else:
-            if dummy_decode_input_data_across_dp is not None:
-                htorch.core.mark_step()
-                _, _, dummy_logits_device = self._execute_model_generic(
-                    dummy_decode_input_data_across_dp.token_ids,
-                    dummy_decode_input_data_across_dp.position_ids,
-                    dummy_decode_input_data_across_dp.attn_metadata,
-                    dummy_decode_input_data_across_dp.logits_indices,
-                    self.kv_caches,
-                    warmup_mode=warmup_mode)
-                htorch.core.mark_step()
 
             ################## Spec Decode ##################
             # work on spec decode if max_gen_len > 1
@@ -2617,6 +2602,17 @@ def execute_model(
                     spec_decode_metadata, spec_decode_common_attn_metadata,
                     decode_data)[:num_decodes]
             ################## Spec Decode end ##################
+        else:
+            if dummy_decode_input_data_across_dp is not None:
+                htorch.core.mark_step()
+                _, _, dummy_logits_device = self._execute_model_generic(
+                    dummy_decode_input_data_across_dp.token_ids,
+                    dummy_decode_input_data_across_dp.position_ids,
+                    dummy_decode_input_data_across_dp.attn_metadata,
+                    dummy_decode_input_data_across_dp.logits_indices,
+                    self.kv_caches,
+                    warmup_mode=warmup_mode)
+                htorch.core.mark_step()
 
         if structured_output:
             # Scheduler places cached before prompt
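
For context, here is a minimal, self-contained sketch of the calling convention this diff settles on: the decode-input helper folds the DP-aware padding directly into the padded batch size and returns only the decode data, while the prepare step pairs that data with None so its caller's two-value unpacking keeps working. Class and helper names below (RunnerSketch, create_decode_input_data, prepare_decode_inputs, dp_size) are hypothetical illustrations, not the actual vllm_gaudi code.

# Sketch only: hypothetical names, simplified logic.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DecodeInputData:
    padded_batch_size: int


class RunnerSketch:

    def get_dp_padding(self, batch_size: int, dp_size: int = 4) -> int:
        # Pad so every data-parallel rank runs the same batch size.
        remainder = batch_size % dp_size
        return 0 if remainder == 0 else dp_size - remainder

    def create_decode_input_data(self, num_decodes: int) -> DecodeInputData:
        padded_batch_size = num_decodes
        # dp aware padding, folded in directly rather than kept as a
        # separate num_pad_across_dp return value
        padded_batch_size += self.get_dp_padding(padded_batch_size)
        return DecodeInputData(padded_batch_size=padded_batch_size)

    def prepare_decode_inputs(
            self, num_decodes: int) -> tuple[DecodeInputData, Optional[int]]:
        # Callers that expect a second value still get one: (data, None).
        return self.create_decode_input_data(num_decodes), None


if __name__ == "__main__":
    data, pad_info = RunnerSketch().prepare_decode_inputs(num_decodes=6)
    print(data.padded_batch_size, pad_info)  # with dp_size=4: 8 None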
