[Kernel] FlashMLA integration#13747
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To trigger CI, PR reviewers can add 🚀
elif block_size != 64:
    logger.warning(
        "FlashMLA backend is not supported for block size %d"
        " (currently only supports block size 64).",
        block_size)
You can update the config here (Line 111 in 4a8cfc7): check the env var and change the block size (with an info-level log message).
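The suggestion above could be sketched as follows. `VLLM_ATTENTION_BACKEND` is the real vLLM env var; the helper name and override logic are assumptions for illustration, not vLLM's actual implementation:

```python
import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def maybe_override_block_size(block_size: int) -> int:
    """Hypothetical config hook: when the FlashMLA backend is requested
    via the env var, force block size 64 with an info-level log instead
    of warning and falling back to another backend."""
    if os.environ.get("VLLM_ATTENTION_BACKEND") == "FLASHMLA" and block_size != 64:
        logger.info(
            "FlashMLA requires block size 64; overriding block size %d -> 64.",
            block_size)
        return 64
    return block_size
```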
Done, this helps thanks! We will still need to figure out a better solution if we ever wanted to make it default though since this relies on the env var
Theoretically, you can set the env var inside this function, and it should be respected later.
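A minimal sketch of that idea (the function name is hypothetical; `VLLM_ATTENTION_BACKEND` is the real vLLM env var):

```python
import os


def select_flashmla_backend() -> None:
    # Hypothetical config-time hook: export the env var from inside the
    # function so that later backend resolution picks it up. setdefault
    # preserves any value the user already set explicitly.
    os.environ.setdefault("VLLM_ATTENTION_BACKEND", "FLASHMLA")
```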
For now, maybe using the env var is fine, so people can try it out before we turn it on by default.
Ya, I'd like to do some cursory benchmarking for a few different workloads before turning it on by default 👍, but I suspect we will ultimately turn it on by default in the next couple of days since it should be much faster than Triton.
1xH100
Looks like FlashMLA has higher throughput (5%-10%) but trades off latency (~1%). BTW, can you post the GPU model used in this test?
Note that this is a small model.
I'd like to know whether this patch already works on 2× 8xH100. And can anybody run an R1 671B benchmark on it?
mgoin
left a comment
Nice work!! I think the performance benefit should be greatest at very large seq len.
8xH200
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Hi @LucasWilkinson, great PR! I am trying to reproduce the numbers in a local environment, but hit the following installation issue. I am using image
Can you please provide the output of
@LucasWilkinson, thanks for the great work. I found that FlashMLA improved throughput by about 10% in throughput testing, but why did the output token throughput of triton_mla and flash_mla not improve in latency testing? Thanks a lot!
I think there is a minor issue in the command. We should use
I observed this error, and after resetting to 145944c I still saw it, so it is not related to this PR. See the similar issue #5587.
Hi @LucasWilkinson, when prefix caching is enabled and two identical requests are sent, this error is reported once the cache hits.
I think that's related to PR #12639, not the FlashMLA one, but I'll investigate and open a PR to disable prefix caching + MLA in the meantime. Thanks for the report!
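The interim workaround described here could be a simple config guard; a sketch under assumed names (the function and its arguments are hypothetical, not vLLM's actual API):

```python
def check_mla_prefix_caching(use_mla: bool, enable_prefix_caching: bool) -> bool:
    # Hypothetical guard: until prefix caching works with MLA backends,
    # reject the combination up front rather than failing later on a
    # cache hit, as in the error reported above.
    if use_mla and enable_prefix_caching:
        raise ValueError(
            "Prefix caching is not yet supported with MLA attention; "
            "disable prefix caching or use a non-MLA backend.")
    return True
```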
Thanks a lot for your reply! Looking forward to having prefix caching supported with MLA!
My results on H200:

VLLM_ATTENTION_BACKEND=TRITON_MLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --load-format dummy --input-len 2000 --output-len 1000 --num-prompts 60 -tp 8
VLLM_ATTENTION_BACKEND=FLASHMLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --load-format dummy --input-len 2000 --output-len 1000 --num-prompts 60 -tp 8
Throughput: 1.41 requests/s, 4235.21 total tokens/s, 1411.74 output tokens/s

VLLM_ATTENTION_BACKEND=TRITON_MLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --load-format dummy --input-len 5000 --output-len 1000 --num-prompts 60 -tp 8
VLLM_ATTENTION_BACKEND=FLASHMLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --load-format dummy --input-len 5000 --output-len 1000 --num-prompts 60 -tp 8

VLLM_ATTENTION_BACKEND=TRITON_MLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --load-format dummy --input-len 10000 --output-len 1000 --num-prompts 60 -tp 8
VLLM_ATTENTION_BACKEND=FLASHMLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --load-format dummy --input-len 10000 --output-len 1000 --num-prompts 60 -tp 8
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Hi @billishyahao, have you successfully run FlashMLA in
Hi @ZhongYingMatrix, we observe the same symptom. The vLLM OpenAI image is buggy: CUDA nvcc is unintentionally downgraded to CUDA 12.1 rather than 12.4 (the base CUDA image). As a quick workaround, I would recommend you try
FYI, vLLM is upgrading to CUDA 12.4 as the default in the next release (v0.8.0).
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>


Integrate: https://github.com/deepseek-ai/FlashMLA
Currently requires
and
TODO:
- cuda-graphs are broken

Closes #13735