@@ -1448,6 +1448,20 @@ def _get_num_decodes(self) -> int:
14481448 num_decodes += 1
14491449 return num_decodes
14501450
1451+ def _is_prompt (self , i : int , scheduler_output : "SchedulerOutput" ) -> bool :
1452+ req_id = self .input_batch .req_ids [i ]
1453+ num_computed_tokens = int (self .input_batch .num_computed_tokens_cpu [i ])
1454+ num_prompt_tokens = int (self .input_batch .num_prompt_tokens [i ])
1455+ num_scheduled_tokens = scheduler_output .num_scheduled_tokens .get (req_id )
1456+ spec_decode_tokens = scheduler_output .scheduled_spec_decode_tokens .get (req_id )
1457+
1458+ num_decode_tokens = 1 if spec_decode_tokens is None else len (spec_decode_tokens ) + 1
1459+ is_prompt = num_computed_tokens < num_prompt_tokens # normal prompt
1460+ is_prompt = is_prompt or num_scheduled_tokens > num_decode_tokens # maybe preempted prompt
1461+ is_prompt = is_prompt and not self .is_decoder_only (req_id )
1462+
1463+ return is_prompt
1464+
14511465 def _get_prompts_and_decodes (
14521466 self ,
14531467 scheduler_output : "SchedulerOutput" ,
@@ -1471,7 +1485,6 @@ def _get_prompts_and_decodes(
14711485 requests = scheduler_output .kv_connector_metadata .requests
14721486 else :
14731487 requests = None
1474-
14751488 # Traverse decodes first
14761489 decode_req_ids = []
14771490 num_computed_tokens_decode = []
@@ -1485,19 +1498,11 @@ def _get_prompts_and_decodes(
14851498 self .input_batch .req_type [req_id ] = requests_type [req_id ]
14861499 break
14871500
1488- num_computed_tokens = self .input_batch .num_computed_tokens_cpu [i ]
1489- num_prompt_tokens = self .input_batch .num_prompt_tokens [i ]
1490- num_scheduled_tokens = scheduler_output .num_scheduled_tokens [req_id ]
1491- if num_computed_tokens < num_prompt_tokens and \
1492- not self .is_decoder_only (req_id ):
1493- # This is prompt
1501+ if self ._is_prompt (i , scheduler_output ):
14941502 break
14951503
1496- # This is decode
1497- # NOTE(chendi): To support spec decode,
1498- # we don't assume num_scheduled_tokens == 1.
1499-
15001504 decode_req_ids .append (req_id )
1505+ num_computed_tokens = self .input_batch .num_computed_tokens_cpu [i ]
15011506 num_computed_tokens_decode .append (int (num_computed_tokens + 1 ))
15021507
15031508 if self .profiler .enabled :
@@ -1510,16 +1515,12 @@ def _get_prompts_and_decodes(
15101515 req_id = self .input_batch .req_ids [i ]
15111516 assert req_id is not None
15121517
1513- num_computed_tokens = self .input_batch .num_computed_tokens_cpu [i ]
1514- num_prompt_tokens = self .input_batch .num_prompt_tokens [i ]
1515- num_scheduled_tokens = scheduler_output .num_scheduled_tokens [req_id ]
1516-
15171518 # Must be prompt
1518- assert num_computed_tokens < num_prompt_tokens
1519- # NOTE(kzawora): In preempted sequences, num_output_tokens can be > 0, and still be a valid prefill
1519+ assert self ._is_prompt (i , scheduler_output )
15201520
15211521 prompt_req_ids .append (req_id )
1522- prompt_scheduled_tokens .append (num_scheduled_tokens )
1522+ num_scheduled_tokens = scheduler_output .num_scheduled_tokens [req_id ]
1523+ prompt_scheduled_tokens .append (int (num_scheduled_tokens ))
15231524
15241525 return PromptDecodeInfo (prompt_req_ids , decode_req_ids , prompt_scheduled_tokens )
15251526
0 commit comments