Add env variable VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K to disable FlashInfer concat_mla_k #35016
maazmusameta wants to merge 1 commit into vllm-project:main
…hInfer concat_mla_k
Summary: Add an environment variable check to allow disabling the FlashInfer concat_mla_k kernel optimization. Setting VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K=1 will bypass this optimization, which is useful for debugging or when replaying components on CUDA where FlashInfer may not work correctly.
Test Plan: VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K=1 buck2 run //vllm:test_mla_attention
Differential Revision: D93967992
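The gating behavior described in the summary can be sketched in isolation. This is a minimal illustration, not the PR's code: the helper name use_flashinfer_concat_mla_k and its flashinfer_available parameter are hypothetical, and any further conditions in the real expression (the diff excerpt is cut off) are omitted.

```python
import os
from typing import Mapping, Optional

ENV_NAME = "VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K"


def use_flashinfer_concat_mla_k(
    flashinfer_available: bool,
    environ: Optional[Mapping[str, str]] = None,
) -> bool:
    """Hypothetical helper mirroring the check added in this PR.

    The optimization stays enabled unless the variable is exactly "1".
    """
    env = os.environ if environ is None else environ
    disabled = env.get(ENV_NAME, "0") == "1"
    return flashinfer_available and not disabled


if __name__ == "__main__":
    # Unset: the optimization is used when FlashInfer is available.
    print(use_flashinfer_concat_mla_k(True, {}))
    # Set to "1": the optimization is bypassed.
    print(use_flashinfer_concat_mla_k(True, {ENV_NAME: "1"}))
```

Note that a string comparison against "1" means values such as "true" or "yes" would not disable the kernel; only the literal 1 shown in the test plan does.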
Code Review
This pull request introduces an environment variable VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K to disable the FlashInfer concat_mla_k kernel optimization, which is a useful addition for debugging purposes. The implementation is straightforward and correct. My only suggestion is to centralize the environment variable handling by using the vllm.envs module, which is the standard pattern in this codebase. This will improve code consistency and maintainability.
    # num_heads=128, nope_dim=128, rope_dim=64
    self._use_flashinfer_concat_mla_k = (
        has_flashinfer()
        and os.environ.get("VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K", "0") != "1"
For consistency with how other vLLM-specific environment variables are handled, it would be better to manage VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K through the vllm.envs module. This centralizes environment variable management and makes the code cleaner.
You can add the new environment variable to vllm/envs.py like this:
    # In vllm/envs.py
    'VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K': lambda: os.getenv("VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K", "0") == "1",
Then, you can use it here as suggested. With this change, the import os at the top of this file is no longer needed and can be removed.
Suggested change:
-    and os.environ.get("VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K", "0") != "1"
+    and not envs.VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K
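For readers unfamiliar with the pattern the reviewer refers to, here is a self-contained sketch of a centralized, lazily evaluated environment-variable registry. It is only an approximation under stated assumptions: the _Envs class and should_use_concat_mla_k function are hypothetical stand-ins, and the real vllm.envs module may differ in details such as type declarations or caching.

```python
import os
from typing import Any, Callable, Dict

# Hypothetical stand-in for a registry like the one in vllm/envs.py:
# each entry maps a variable name to a lambda that parses it on access.
environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K": lambda: os.getenv(
        "VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K", "0"
    ) == "1",
}


class _Envs:
    """Hypothetical accessor that resolves attributes through the registry."""

    def __getattr__(self, name: str) -> Any:
        if name in environment_variables:
            return environment_variables[name]()
        raise AttributeError(f"unknown environment variable: {name!r}")


envs = _Envs()


def should_use_concat_mla_k(flashinfer_available: bool) -> bool:
    # Call site in the style of the suggested change: no direct os.environ
    # access, only a typed flag read from the central registry.
    return flashinfer_available and not envs.VLLM_DISABLE_FLASHINFER_CONCAT_MLA_K
```

Keeping the parsing in one place means the call site reads a boolean rather than comparing strings, and a misspelled variable name fails with an AttributeError instead of silently falling back to the default.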