Optimize the inference performance of the FLA operator on Qwen3.5 Model #7597

mikequan0425 wants to merge 2 commits into vllm-project:main
Conversation
Clean up debug print statements and commented-out code to keep the code tidy
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on optimizing the inference performance of the FLA (Fused Linear Attention) operator, particularly for the Qwen3.5 model, by enhancing how chunk-related metadata is prepared and managed. The core improvement introduces a caching mechanism in the fla_utils module that prevents redundant computation of chunk indices and offsets. This refactoring, coupled with the integration of new utility functions across the FLA kernels, aims to reduce computational overhead and improve overall efficiency during model inference.
Code Review
This pull request significantly optimizes the Flash-Linear-Attention (FLA) operator by refactoring and caching chunk metadata preparation. The changes replace inefficient Python-level loops with vectorized PyTorch operations, and introduce a caching layer to avoid redundant computations, which also removes expensive CPU-GPU synchronizations. The PR also includes new unit tests to validate the correctness of these optimizations.
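The chunk-metadata quantities being vectorized and cached can be illustrated with a plain-Python reference sketch. The PR itself computes these values with vectorized PyTorch operations; the function and argument names below are illustrative, following the common `cu_seqlens` (cumulative sequence lengths) convention of FLA kernels:

```python
def prepare_chunk_offsets(cu_seqlens, chunk_size):
    """Cumulative chunk offsets per sequence from cumulative sequence lengths.

    cu_seqlens = [0, s1, s1+s2, ...]; a sequence of length L contributes
    ceil(L / chunk_size) chunks.
    """
    offsets = [0]
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        num_chunks = (end - start + chunk_size - 1) // chunk_size  # ceil division
        offsets.append(offsets[-1] + num_chunks)
    return offsets


def prepare_chunk_indices(cu_seqlens, chunk_size):
    """Flattened (sequence_id, chunk_id) pairs over all sequences."""
    indices = []
    for seq_id, (start, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
        num_chunks = (end - start + chunk_size - 1) // chunk_size
        indices.extend((seq_id, c) for c in range(num_chunks))
    return indices
```

For `cu_seqlens = [0, 5, 8]` and `chunk_size = 4`, the two sequences have lengths 5 and 3, so the offsets are `[0, 2, 3]` and the indices `[(0, 0), (0, 1), (1, 0)]`. Because these values depend only on `cu_seqlens` and `chunk_size`, they are natural candidates for the caching the PR adds.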
My main feedback is to address a potential memory leak in the new caching implementation by adding a size limit to the cache.
Following the repository's style guide, I've also suggested an updated PR title and summary.
Suggested PR Title:

[Ops][Feature] Optimize FLA operator performance by caching metadata

Suggested PR Summary:
### What this PR does / why we need it?
This PR optimizes the Flash-Linear-Attention (FLA) operator by improving the performance of chunk metadata preparation.
- The utility functions in `vllm_ascend/ops/triton/fla/utils.py` are refactored to use efficient, vectorized PyTorch operations instead of Python loops and list comprehensions.
- A caching mechanism is introduced for these utility functions to avoid recomputing metadata for the same input tensors. The cache uses the tensor's ID and version for keying.
- This change avoids expensive CPU-GPU synchronization that was present in the previous implementation.
- Other FLA operator files are updated to use the new optimized utility functions.
- Unit tests are added to verify the new utility functions and their caching logic.
This optimization improves the overall inference performance of models using the FLA operator.
### Does this PR introduce _any_ user-facing change?
No. This is a performance optimization and should not change any user-facing behavior.
### How was this patch tested?
- Added new unit tests in `tests/ut/ops/test_fla_utils.py` to verify the correctness of the refactored utility functions and the caching mechanism.
- CI passed with new and existing tests.

```python
def _cache_prepare_result(cu_seqlens: torch.LongTensor, chunk_size: int, name: str, value):
    key = _get_prepare_cache_key(cu_seqlens, chunk_size, name)
    _PREPARE_CACHE[key] = (weakref.ref(cu_seqlens), value)
    return value
```
The `_PREPARE_CACHE` dictionary is unbounded and can grow indefinitely if many different `cu_seqlens` tensors are used over the lifetime of the application, which can lead to a memory leak. Although `weakref` is used for the tensor, cache entries for garbage-collected tensors are not proactively removed; they are only overwritten if a new tensor happens to reuse the same `id`, which is not a reliable cleanup mechanism.
To prevent potential out-of-memory errors, you should bound the cache size. A simple approach is to evict items when the cache exceeds a certain threshold.
For example, you could implement a simple FIFO eviction policy:

```python
# At module level
_PREPARE_CACHE_MAX_SIZE = 256

# In _cache_prepare_result
def _cache_prepare_result(cu_seqlens: torch.LongTensor, chunk_size: int, name: str, value):
    if len(_PREPARE_CACHE) >= _PREPARE_CACHE_MAX_SIZE:
        # Evict the first item inserted (FIFO). Relies on dict insertion order (Python 3.7+).
        _PREPARE_CACHE.pop(next(iter(_PREPARE_CACHE)))
    key = _get_prepare_cache_key(cu_seqlens, chunk_size, name)
    _PREPARE_CACHE[key] = (weakref.ref(cu_seqlens), value)
    return value
```

A more robust solution would be to use an LRU (Least Recently Used) cache.
This pull request has conflicts; please resolve them before we can evaluate the pull request.
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?