Skip to content

Commit 7b9b62d

Browse files
committed
fix preempted prompts
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
1 parent 9ce14a2 commit 7b9b62d

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1448,6 +1448,20 @@ def _get_num_decodes(self) -> int:
14481448
num_decodes += 1
14491449
return num_decodes
14501450

1451+
def _is_prompt(self, i: int, scheduler_output: "SchedulerOutput") -> bool:
1452+
req_id = self.input_batch.req_ids[i]
1453+
num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[i])
1454+
num_prompt_tokens = int(self.input_batch.num_prompt_tokens[i])
1455+
num_scheduled_tokens = scheduler_output.num_scheduled_tokens.get(req_id)
1456+
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
1457+
1458+
num_decode_tokens = 1 if spec_decode_tokens is None else len(spec_decode_tokens) + 1
1459+
is_prompt = num_computed_tokens < num_prompt_tokens # normal prompt
1460+
is_prompt = is_prompt or num_scheduled_tokens > num_decode_tokens # maybe preempted prompt
1461+
is_prompt = is_prompt and not self.is_decoder_only(req_id)
1462+
1463+
return is_prompt
1464+
14511465
def _get_prompts_and_decodes(
14521466
self,
14531467
scheduler_output: "SchedulerOutput",
@@ -1471,7 +1485,6 @@ def _get_prompts_and_decodes(
14711485
requests = scheduler_output.kv_connector_metadata.requests
14721486
else:
14731487
requests = None
1474-
14751488
# Traverse decodes first
14761489
decode_req_ids = []
14771490
num_computed_tokens_decode = []
@@ -1485,19 +1498,11 @@ def _get_prompts_and_decodes(
14851498
self.input_batch.req_type[req_id] = requests_type[req_id]
14861499
break
14871500

1488-
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
1489-
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
1490-
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
1491-
if num_computed_tokens < num_prompt_tokens and \
1492-
not self.is_decoder_only(req_id):
1493-
# This is prompt
1501+
if self._is_prompt(i, scheduler_output):
14941502
break
14951503

1496-
# This is decode
1497-
# NOTE(chendi): To support spec decode,
1498-
# we don't assume num_scheduled_tokens == 1.
1499-
15001504
decode_req_ids.append(req_id)
1505+
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
15011506
num_computed_tokens_decode.append(int(num_computed_tokens + 1))
15021507

15031508
if self.profiler.enabled:
@@ -1510,16 +1515,12 @@ def _get_prompts_and_decodes(
15101515
req_id = self.input_batch.req_ids[i]
15111516
assert req_id is not None
15121517

1513-
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
1514-
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
1515-
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
1516-
15171518
# Must be prompt
1518-
assert num_computed_tokens < num_prompt_tokens
1519-
# NOTE(kzawora): In preempted sequences, num_output_tokens can be > 0, and still be a valid prefill
1519+
assert self._is_prompt(i, scheduler_output)
15201520

15211521
prompt_req_ids.append(req_id)
1522-
prompt_scheduled_tokens.append(num_scheduled_tokens)
1522+
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
1523+
prompt_scheduled_tokens.append(int(num_scheduled_tokens))
15231524

15241525
return PromptDecodeInfo(prompt_req_ids, decode_req_ids, prompt_scheduled_tokens)
15251526

0 commit comments

Comments (0)