Conversation
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
Porting #830
Pull request overview
This PR fixes a bug where preempted prompts were incorrectly classified as decode requests, causing runtime errors. The fix introduces a dedicated method to properly identify prompt vs decode requests by considering scheduled tokens alongside computed tokens.
Changes:
- Added `_is_prompt()` method to correctly identify prompt requests, including preempted ones
- Refactored prompt/decode classification logic to use the new method
- Added preemption handling test to CI pipeline
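The classification fix can be illustrated with a small, self-contained sketch. This is a simplified stand-alone function, not the actual vLLM code: the function name `is_prompt` and the exact return condition are assumptions based on the PR description (a preempted request may have already computed all its prompt tokens, so the scheduler-side token count is what distinguishes a prefill-like recompute from a normal single-token decode step).

```python
def is_prompt(num_computed_tokens: int,
              num_prompt_tokens: int,
              num_scheduled_tokens: int,
              num_spec_decode_tokens: int = 0) -> bool:
    # Original criterion: the request is still inside its prompt.
    if num_computed_tokens < num_prompt_tokens:
        return True
    # Preempted prompt: all prompt tokens were computed before preemption,
    # but the scheduler re-schedules more than one token (a prefill-like
    # recompute), so it must still be treated as a prompt. A normal decode
    # step schedules exactly one token (plus any speculative tokens).
    return num_scheduled_tokens > 1 + num_spec_decode_tokens

# A normal decode step: one scheduled token -> decode.
assert is_prompt(100, 10, 1) is False
# A preempted prompt: computed tokens exceed prompt tokens, yet many
# tokens are scheduled for recompute -> prompt.
assert is_prompt(100, 10, 64) is True
# An in-progress prefill -> prompt by the original criterion.
assert is_prompt(4, 10, 6) is True
```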
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| vllm_gaudi/v1/worker/hpu_model_runner.py | Implements _is_prompt() method and refactors prompt/decode classification logic to fix preemption handling |
| tests/full_tests/preemption.py | Adds new test script to verify preemption handling works correctly |
| tests/full_tests/ci_gsm8k_tests.sh | Integrates preemption test into CI pipeline |
```python
def _is_prompt(self, i: int, scheduler_output: "SchedulerOutput") -> bool:
    req_id = self.input_batch.req_ids[i]
    num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[i])
    num_prompt_tokens = int(self.input_batch.num_prompt_tokens[i])
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens.get(req_id)
    spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
```
The method could fail without a helpful error message if req_id is None or if dictionary lookups return None. Consider adding validation and raising descriptive errors when required values are missing.
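One possible shape for the validation this comment suggests (a sketch only; `lookup_scheduled_tokens` is a hypothetical stand-alone helper, with plain Python containers standing in for the batch and scheduler structures in the diff):

```python
def lookup_scheduled_tokens(req_ids, i, num_scheduled_tokens):
    """Fail fast with a descriptive error instead of letting a None
    propagate into later arithmetic."""
    req_id = req_ids[i]
    if req_id is None:
        raise ValueError(f"request at batch index {i} has no request id")
    tokens = num_scheduled_tokens.get(req_id)
    if tokens is None:
        raise KeyError(
            f"request {req_id!r} missing from num_scheduled_tokens; "
            "was it scheduled in this step?")
    return tokens
```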
```python
if self._is_prompt(i, scheduler_output):
    break

# This is decode
```
`# This is decode`

Why remove this comment?
It's a comment for the older statement `assert num_scheduled_tokens == 1`, which is no longer relevant to the current code.
```python
for layer in model.language_model.model.layers:
    if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__:
        layer.self_attn.attn.impl.is_chunked_attention = True
except Exception:
```
When will there be an exception?
Silently passing on an exception is dangerous in most cases. I suggest either adding a warning message or making sure there is no exception.
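A sketch of what the reviewer is asking for, assuming the code were kept: narrow the `except` and log instead of swallowing the error. `enable_chunked_attention` is a hypothetical wrapper around the loop from the diff, not the actual code (which the author removed below).

```python
import logging

logger = logging.getLogger(__name__)

def enable_chunked_attention(model):
    try:
        for layer in model.language_model.model.layers:
            backend = layer.self_attn.attn.get_attn_backend().__name__
            if "ChunkedLocalAttention" in backend:
                layer.self_attn.attn.impl.is_chunked_attention = True
    except AttributeError as e:
        # Narrow exception type, and make the failure visible instead
        # of a bare `pass`.
        logger.warning("Skipping chunked-attention setup: %s", e)
```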
Oh, sorry. It's redundant code from cherry-picking and not relevant to this PR.
Fixed.
```python
except Exception:
    pass

def _is_prompt(self, i: int, scheduler_output: "SchedulerOutput") -> bool:
```
Suggest renaming `i` to a more meaningful name in the function.
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
### Motivation
Preempted prompts might fail to match the `num_computed_tokens < num_prompt_tokens` test and be treated as decode, causing a runtime error.
### Changes
- Add `_is_prompt()` to check whether a request is a prompt.
- Consider `num_scheduled_tokens` to handle preempted prompts.
- Add a preemption-handling test to the CI.
---------
Signed-off-by: Youlei Yang <youlei.yang@intel.com>