[wip][Spec Decoding] Zero-bubble async scheduling + spec decoding #7640
HF-001 wants to merge 3 commits into vllm-project:main
Conversation
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Summary of Changes (Gemini Code Assist): This pull request improves speculative decoding performance by introducing zero-bubble asynchronous scheduling. The core idea is to optimistically process draft tokens on the CPU, assuming they are all accepted, and then apply the necessary corrections on the NPU after the model's forward pass, reducing latency and improving hardware utilization. The changes shift how sequence lengths and computed-token counts are managed across CPU and GPU, move KV-cache slot-mapping computation into a new kernel, and add deferred state corrections to keep data consistent under this asynchronous execution model.
Code Review
This PR implements significant improvements for asynchronous speculative decoding and NPU (Ascend) specific optimizations within vLLM. It shifts toward GPU-centric state management for attention metadata, introducing optimistic_seq_lens_cpu for speculative decoding and moving slot-mapping computation to a GPU kernel. It also adds deferred state corrections and re-synchronization for Mamba cache alignment. A critical review comment flags a potential integer-overflow risk in changing num_accepted_tokens_cpu_tensor from torch.int64 to torch.int32.
```diff
 # Speculative decoding
 self.num_accepted_tokens_cpu_tensor = torch.ones(
-    (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
+    (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
```
Changing the dtype of num_accepted_tokens_cpu_tensor from torch.int64 to torch.int32 could lead to an integer overflow if the number of accepted tokens for a request exceeds the maximum value for a 32-bit signed integer (2,147,483,647). Please confirm that int32 is sufficient for all expected scenarios, or revert to int64 to prevent potential data loss or incorrect behavior.
Suggested change:

```diff
-    (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
+    (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
```
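For context on the risk: PyTorch integer tensors wrap around silently on overflow rather than raising, so an int32 counter that crosses 2**31 - 1 turns negative without warning. A minimal demonstration:

```python
import torch

# int32 tensors wrap around silently on overflow (no exception is raised)
x = torch.tensor([2**31 - 1], dtype=torch.int32)
y = x + 1  # wraps to the minimum int32 value
print(y.item())   # -2147483648

# the same arithmetic is lossless in int64
z = torch.tensor([2**31 - 1], dtype=torch.int64) + 1
print(z.item())   # 2147483648
```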
```python
if common_attn_metadata.seq_lens_cpu is not None:
    common_attn_metadata.seq_lens_cpu[:batch_size] = common_attn_metadata.seq_lens_cpu[:batch_size] + 1
    exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= self.max_model_len
    common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_(exceeds_mask, 1)
if common_attn_metadata.num_computed_tokens_cpu is not None:
```
The addition of if ... is not None checks for common_attn_metadata.seq_lens_cpu and common_attn_metadata.num_computed_tokens_cpu is a critical improvement. This prevents AttributeError in scenarios where these attributes might be None due to the async spec decode logic, ensuring robustness and correctness.
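The guarded increment-and-clamp can be exercised standalone; a minimal sketch with assumed values (max_model_len and the tensor contents are illustrative only):

```python
import torch

max_model_len = 8  # assumed value for illustration
seq_lens_cpu = torch.tensor([3, 8, 7], dtype=torch.int32)
batch_size = 3

# The None guard avoids an AttributeError when async spec decode has
# intentionally set seq_lens_cpu to None (GPU tensors are authoritative).
if seq_lens_cpu is not None:
    seq_lens_cpu[:batch_size] = seq_lens_cpu[:batch_size] + 1
    exceeds_mask = seq_lens_cpu[:batch_size] >= max_model_len
    seq_lens_cpu[:batch_size].masked_fill_(exceeds_mask, 1)

print(seq_lens_cpu.tolist())  # [4, 1, 1]
```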
```python
# Update num_computed_tokens on GPU. In async spec decode,
# CPU values are optimistic (all drafts accepted). The kernel
# corrects on GPU using the previous step's
# valid_sampled_token_count_gpu. Otherwise, just copy from CPU.
if (
    self.use_async_spec_decode
    and self.valid_sampled_token_count_gpu is not None
    and prev_req_id_to_index
):
    self.prev_positions.copy_to_gpu(num_reqs)
    self.prev_num_draft_tokens.copy_to_gpu()
    cpu_values = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs].to(
        device=self.device, non_blocking=True
    )
    update_num_computed_tokens_for_batch_change(
        self.num_computed_tokens,
        self.num_accepted_tokens.gpu[:num_reqs],
        self.prev_positions.gpu[:num_reqs],
        self.valid_sampled_token_count_gpu,
        self.prev_num_draft_tokens.gpu,
        cpu_values,
    )
else:
    self.num_computed_tokens[:num_reqs].copy_(
        self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs],
        non_blocking=True,
    )
```
The logic for conditionally updating self.num_computed_tokens based on use_async_spec_decode is a core part of the asynchronous scheduling. When use_async_spec_decode is enabled, the GPU-side correction using update_num_computed_tokens_for_batch_change is essential for maintaining data consistency between the optimistic CPU state and the authoritative NPU state. This is a critical correctness change for the new speculative decoding approach.
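The exact arithmetic lives in the update_num_computed_tokens_for_batch_change kernel, but the intent of the correction can be sketched in plain PyTorch. The optimistic CPU value counts every draft token as accepted, so the correction subtracts the drafts that were actually rejected. All tensor names and the accounting below are illustrative assumptions, not the kernel's real signature or math:

```python
import torch

# Optimistic CPU bookkeeping: every draft token assumed accepted.
optimistic_cpu = torch.tensor([10, 12])   # num_computed_tokens if all drafts landed
prev_num_draft = torch.tensor([3, 3])     # drafts proposed in the previous step
valid_sampled = torch.tensor([4, 2])      # tokens actually emitted (accepted drafts + 1 bonus)

# Correction: remove the drafts the verifier rejected.
rejected = prev_num_draft - (valid_sampled - 1)
num_computed_tokens = optimistic_cpu - rejected
print(num_computed_tokens.tolist())  # [10, 10]
```

Request 0 had all 3 drafts accepted, so its optimistic value stands; request 1 had 2 drafts rejected, so its count is walked back by 2.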
```python
self.input_batch.block_table.compute_slot_mapping(
    num_reqs,
    self.query_start_loc.gpu[: num_reqs + 1],
    self.positions[:total_num_scheduled_tokens],
)
```
The compute_slot_mapping call has been moved and now correctly uses the GPU-side self.positions and self.query_start_loc.gpu. This is a critical change to ensure that the slot mapping is computed based on the most up-to-date and authoritative GPU state, which is essential for the attention mechanism.
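For readers unfamiliar with slot mapping: in a paged KV cache, a token's absolute position is translated to a physical cache slot through the per-request block table. A simplified single-request version of that computation (block_size and all tensor contents are illustrative):

```python
import torch

block_size = 4
# block_table[req, i] = physical block id backing logical block i of request `req`
block_table = torch.tensor([[7, 2, 5],
                            [1, 9, 0]])

# e.g. the tokens at positions 5..7 of request 0
req_id = 0
positions = torch.tensor([5, 6, 7])

block_ids = block_table[req_id, positions // block_size]   # logical -> physical block
slot_mapping = block_ids * block_size + positions % block_size
print(slot_mapping.tolist())  # [9, 10, 11]
```

Computing this from the GPU-side positions tensor is what lets the mapping reflect the corrected (not optimistic) token positions.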
```python
if deferred_state_corrections_fn:
    deferred_state_corrections_fn()
    deferred_state_corrections_fn = None
```
Applying deferred_state_corrections_fn before mamba_utils.preprocess_mamba is a critical correctness fix. preprocess_mamba relies on req_state.num_computed_tokens (CPU), so ensuring these corrections are applied beforehand prevents preprocess_mamba from operating on an outdated or optimistic CPU state.
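The deferral mechanism itself reduces to a captured closure that is applied exactly once before any consumer of the CPU state runs. A stripped-down sketch of the pattern (all names and values here are illustrative, not the PR's actual code):

```python
# CPU state starts out optimistic (all drafts assumed accepted).
state = {"num_computed_tokens": 12}

def make_correction(authoritative_value):
    """Capture the correction now, apply it later."""
    def apply():
        state["num_computed_tokens"] = authoritative_value
    return apply

deferred_state_corrections_fn = make_correction(10)

# ... forward pass proceeds against the optimistic state ...

# Before anything that reads the CPU state (e.g. Mamba preprocessing):
if deferred_state_corrections_fn:
    deferred_state_corrections_fn()
    deferred_state_corrections_fn = None  # guarantee at-most-once application

print(state["num_computed_tokens"])  # 10
```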
```python
if self.use_async_spec_decode:
    # GPU tensors are authoritative in async mode.
    seq_lens_cpu = None
    num_computed_tokens_cpu = None
```
Setting seq_lens_cpu and num_computed_tokens_cpu to None when use_async_spec_decode is enabled is a critical change. This explicitly signals that in async mode, the GPU tensors are authoritative, preventing accidental reliance on potentially optimistic or outdated CPU values in AscendCommonAttentionMetadata. This is crucial for maintaining the integrity of the async scheduling logic.
```python
self.positions[:total_num_scheduled_tokens] = (
    self.num_computed_tokens[req_indices_gpu].to(torch.int64)
    + self.query_pos.gpu[:total_num_scheduled_tokens]
)
self.seq_lens[:num_reqs] = (
    self.num_computed_tokens[:num_reqs] + num_scheduled_tokens_gpu
)
```
The calculation of self.positions and self.seq_lens directly on the GPU using self.num_computed_tokens and self.query_pos.gpu is a significant change. This aligns with the strategy of making NPU-side tensors the source of truth and reduces CPU-GPU synchronization overhead. This is a high-severity correctness and performance improvement.
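The same arithmetic is easy to check on CPU tensors. With two requests that have 5 and 2 already-computed tokens and 3 and 2 newly scheduled tokens (values assumed for illustration):

```python
import torch

num_computed_tokens = torch.tensor([5, 2], dtype=torch.int32)
num_scheduled_tokens = torch.tensor([3, 2], dtype=torch.int32)

# One entry per scheduled token: owning request, and offset within its query.
req_indices = torch.tensor([0, 0, 0, 1, 1])
query_pos = torch.tensor([0, 1, 2, 0, 1])

positions = num_computed_tokens[req_indices].to(torch.int64) + query_pos
seq_lens = num_computed_tokens + num_scheduled_tokens

print(positions.tolist())  # [5, 6, 7, 2, 3]
print(seq_lens.tolist())   # [8, 4]
```

Because every operand already lives on the device in the real code, no host round-trip is needed to derive positions or sequence lengths.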
What this PR does / why we need it?
Refer to vllm-project/vllm#32951. This PR improves the async-ness of spec decoding by optimistically assuming on the CPU that all draft tokens are accepted, and deferring the correction until after the forward pass. The NPU-side tensors are taken as the source of truth.
The feature currently works correctly, but there is a slight performance regression, likely caused by the Triton operator; optimization is in progress.
How was this patch tested?
todo