[Attention] Full CG support for llama4 and remove use of deprecated properties #31851
Conversation
- Moved virtual batch computation from numpy (CPU) to Triton kernel (GPU)
- Added persistent buffers for virtual batch metadata in ChunkedLocalAttentionBuilder
- Changed CG support level from NEVER to UNIFORM_BATCH
- Removed make_local_attention_virtual_batches function from utils.py
- Updated test to use the new builder-based approach
- Added buffer zeroing before kernel to handle masked writes

Still in progress: need to remove torch.cuda.synchronize() and test FULL CG

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
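For readers following the builder changes, here is a minimal sketch of the persistent-buffer-plus-zeroing pattern this commit describes; the class and field names below are illustrative, not the actual ChunkedLocalAttentionBuilder API.

```python
import torch

class VirtualBatchBuffers:
    """Illustrative stand-in for the builder's persistent metadata buffers."""

    def __init__(self, max_virtual_batches: int, max_pages_per_vb: int,
                 device: torch.device):
        # Allocated once at worst-case size and reused across steps, so no
        # per-step allocation is needed (CUDA-graph friendly).
        self.cu_seqlens = torch.empty(
            max_virtual_batches + 1, dtype=torch.int32, device=device
        )
        self.block_table = torch.empty(
            max_virtual_batches, max_pages_per_vb, dtype=torch.int32, device=device
        )

    def prepare_for_kernel(self) -> None:
        # The fill kernel uses masked writes, so slots it skips would otherwise
        # keep stale values from the previous step; zero them first.
        self.cu_seqlens.zero_()
        self.block_table.zero_()
```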
- Removed torch.cuda.synchronize() call since kernel launches are ordered on the same stream
- Removed debug print statements
- Added comment explaining GPU tensor passed as query_start_loc_cpu
- Changed pages_per_vb to use full width (clamping handled in kernel)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
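The synchronize removal relies on CUDA stream ordering: work submitted to the same stream executes in submission order, so no host-side barrier is needed between a kernel and the kernels that consume its output. A small illustrative example (not the PR's kernels):

```python
import torch

if torch.cuda.is_available():
    x = torch.zeros(1024, device="cuda")
    y = x + 1  # kernel 1, enqueued on the current stream
    z = y * 2  # kernel 2, same stream: guaranteed to see kernel 1's result
    # No torch.cuda.synchronize() is needed between the two launches; a sync
    # only matters when the host reads the result.
    print(z[0].item())  # .item() performs the device-to-host sync implicitly
```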
- Test basic correctness with various batch sizes and sequence lengths
- Test edge cases (empty batch, single request, max virtual batches)
- Test large batch scenarios (100+ virtual batches)
- Test block table correctness and indexing invariants
- Test output buffer invariants (monotonicity, bounds)
- All 33 tests pass

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
- Uniform single token decode case: reuse input query_start_loc_cpu directly since num_vb == batch_size and cu_seqlens are identical
- Spec-decode / Prefill case: use pinned memory with non-blocking copy so backends like FlashAttn that don't need query_start_loc_cpu can continue building metadata asynchronously
- No CPU<>GPU sync needed in decode path (FULL CG case)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
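A sketch of the two paths this commit describes; the function and argument names are hypothetical, and the pinned buffer is assumed to be pre-allocated with pin_memory=True.

```python
import torch

def stage_query_start_loc(
    query_start_loc_gpu: torch.Tensor,
    pinned_host_buf: torch.Tensor,  # allocated once with pin_memory=True
    uniform_decode: bool,
) -> torch.Tensor:
    if uniform_decode:
        # Uniform single-token decode: num_vb == batch_size and the cu_seqlens
        # are identical, so the GPU tensor is reused directly (no sync).
        return query_start_loc_gpu
    # Spec-decode / prefill: asynchronous device-to-host copy into pinned
    # memory so backends that don't need the CPU copy can keep building
    # metadata; consumers that read it on the host must synchronize first.
    n = query_start_loc_gpu.numel()
    pinned_host_buf[:n].copy_(query_start_loc_gpu, non_blocking=True)
    return pinned_host_buf[:n]
```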
Signed-off-by: Reagan Lee <reaganjlee@gmail.com> Signed-off-by: Reagan <reaganjlee@gmail.com>
…llm-project#31590) Signed-off-by: c0de128 <kevin.mckay@outlook.com>
…educing agent pool size (vllm-project#31553) Signed-off-by: Andreas Karatzas <akaratza@amd.com>
…t#31569) Signed-off-by: zhima771 <15836938703@163.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
…oject#31513) Signed-off-by: Jay Hemnani <jayhemnani9910@gmail.com> Co-authored-by: Jay Hemnani <jayhemnani9910@gmail.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
…lm-project#31572) Signed-off-by: Tmn07 <tmn0796@gmail.com>
…#31549) Signed-off-by: vaibhav sourirajan <vs2787@columbia.edu>
…code (vllm-project#31282) Signed-off-by: c0de128 <kevin.mckay@outlook.com>
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
Signed-off-by: xiaoming <1259730330@qq.com>
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
…#30739) Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: njhill <nickhill123@gmail.com>
…ery inputs (vllm-project#31596) Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
…-project#31504) Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
…roject#31604) Signed-off-by: Alfred <massif0601@gmail.com>
…_config (vllm-project#28454) Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com> Signed-off-by: Xingyu Liu <38244988+charlotte12l@users.noreply.github.com>
…ring (vllm-project#29255) Signed-off-by: Jeremy Teboul <jeremyteboul@fb.com> Co-authored-by: Jeremy Teboul <jeremyteboul@fb.com>
…roject#31630) Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Robert Shaw <robertgshaw2@gmail.com> Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
…ject#31654) Signed-off-by: Reagan <reaganjlee@gmail.com>
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
…m-project#31137) Signed-off-by: Andy Xie <andy.xning@gmail.com>
This pull request has merge conflicts that must be resolved before it can be merged.
rebase needed; closing to avoid codeowner thrash |
Code Review
This pull request introduces a wide range of changes, including significant refactorings, new features, bug fixes, and performance improvements. Key changes include the introduction of ModelArchitectureConfig for better model property management, a major refactoring of MoE and local attention kernels, and the deprecation of current_platform.seed_everything in favor of a centralized set_random_seed. My review focuses on a critical bug in the GLM4 reasoning parser refactoring and an inconsistency in the usage of the newly introduced set_random_seed function.
class Glm4MoeModelReasoningParser(Holo2ReasoningParser):
    """
    Reasoning parser for the Glm4MoeModel model.

    The Glm4MoeModel model uses <think>...</think> tokens to denote reasoning
    text within its output. The model provides a strict switch to disable
    reasoning output via the 'enable_thinking=False' parameter. This parser
    extracts the reasoning content enclosed by <think> and </think> tokens
    from the model's output.
    Reasoning parser for the Glm4MoeModel model,which inherits from
    `Holo2ReasoningParser`.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        self.think_start_token = "<think>"
        self.think_end_token = "</think>"
        self.assistant_token = "<|assistant|>"

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        self.assistant_token_id = self.vocab.get(self.assistant_token)
        if (
            self.think_start_token_id is None
            or self.think_end_token_id is None
            or self.assistant_token_id is None
        ):
            raise RuntimeError(
                "Glm4MoeModel reasoning parser could not locate "
                "think start/end or assistant tokens in the tokenizer!"
            )

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """
        GLM's chat template has <think></think> tokens after every
        <|assistant|> token. Thus, we need to check if </think> is
        after the most recent <|assistant|> token (if present).
        """
        for token_id in input_ids[::-1]:
            if token_id == self.think_end_token_id:
                return True
            elif token_id == self.assistant_token_id:
                return False
        return False

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content after the end tokens
        """
        if self.think_end_token_id not in input_ids[:-1]:
            return []
        else:
            return input_ids[input_ids.index(self.think_end_token_id) + 1 :]

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning
        - 'xyz' goes to content
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and (
            delta_token_ids[0] in [self.think_start_token_id, self.think_end_token_id]
        ):
            return None

        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.think_end_token)
                reasoning = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token) :]
                return DeltaMessage(
                    reasoning=reasoning,
                    content=content if content else None,
                )
            elif self.think_end_token_id in previous_token_ids:
                # <think> in previous, </think> in previous,
                # reasoning content continues
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning=delta_text)
        elif self.think_start_token_id in delta_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning = delta_text[
                    start_index + len(self.think_start_token) : end_index
                ]
                content = delta_text[end_index + len(self.think_end_token) :]
                return DeltaMessage(
                    reasoning=reasoning,
                    content=content if content else None,
                )
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning=delta_text)
        else:
            # thinking is disabled, just content
            return DeltaMessage(content=delta_text)

    def extract_reasoning(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[str | None, str | None]:
        """
        Extract reasoning content from the model output.

        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning
        - 'xyz' goes to content

        Returns:
            tuple[Optional[str], Optional[str]]: reasoning content and content
        """

        # Check if the model output contains the <think> and </think> tokens.
        if (
            self.think_start_token not in model_output
            or self.think_end_token not in model_output
        ):
            return None, model_output
        # Check if the <think> is present in the model output, remove it
        # if it is present.
        model_output_parts = model_output.partition(self.think_start_token)
        model_output = (
            model_output_parts[2] if model_output_parts[1] else model_output_parts[0]
        )
        # Check if the model output contains the </think> tokens.
        # If the end token is not found, return the model output as is.
        if self.think_end_token not in model_output:
            return None, model_output

        # Extract reasoning content from the model output.
        reasoning, _, content = model_output.partition(self.think_end_token)

        final_content = content or None
        return reasoning, final_content
    pass
This refactoring to inherit from Holo2ReasoningParser introduces a critical bug. Holo2ReasoningParser uses DeepSeekR1ReasoningParser internally, which is designed for DeepSeek's <|think|> and <|end_think|> tokens. However, GLM4 models use <think> and </think>. This change breaks reasoning parsing for GLM4. The original implementation specific to GLM4 should be restored to correctly handle its reasoning tokens.
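To make concrete what the original parser does, here is a standalone sketch of the string splitting performed by the extract_reasoning method shown in the diff above (markers hard-coded for illustration; no tokenizer involved):

```python
think_start, think_end = "<think>", "</think>"
model_output = "<think>abc</think>xyz"

# Mirror extract_reasoning: drop everything up to and including <think>,
# then split on </think> into reasoning and content.
after_start = model_output.partition(think_start)[2] or model_output
reasoning, _, content = after_start.partition(think_end)
print(reasoning)        # "abc"
print(content or None)  # "xyz"
```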
)
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
To align with the deprecation of current_platform.seed_everything, please import set_random_seed here. This will be used to replace the deprecated function call in main.
Suggested change:
- from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
+ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
    seed: int = 0,
    iters: int = 20,
) -> None:
    current_platform.seed_everything(seed)
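Presumably, following the import suggestion above, the corresponding change at this call site replaces the deprecated call (a sketch, assuming set_random_seed(seed) is the drop-in replacement the reviewer references):

```python
# Sketch only: replaces the deprecated current_platform.seed_everything(seed),
# assuming set_random_seed(seed) is the intended centralized helper.
set_random_seed(seed)
```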
Purpose
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.