We are excited to announce Day 0 support for [DeepSeek-V3.2-Exp](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp), featuring DeepSeek Sparse Attention (DSA) ([paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/DeepSeek_V3_2.pdf)) designed for long context tasks. In this post, we showcase how to use this model in vLLM and dive deep into the challenges encountered in supporting DSA in vLLM.
In particular, DSA's lightning indexer along with sparse attention presents challenges in continuous batching and paged attention. For example, we need to handle prefill and decode separately for the indexer module and carefully manage the different cache layouts.
On the performance side, vLLM integrates with the lightning indexer CUDA kernels in DeepGEMM, as well as the new sparse attention kernel in FlashMLA. We are also excited about Blackwell support. In collaboration with NVIDIA, you can run this model directly on B200 and GB200!
To get started with DeepSeek 3.2, please follow the installation instructions in the [recipes](https://docs.vllm.ai/projects/recipes/en/latest/DeepSeek/DeepSeek-V3_2-Exp.html). We are still improving the initial support with this [PR](https://github.com/vllm-project/vllm/pull/25869). For known issues, see the [tracking issue](https://github.com/vllm-project/vllm/issues/25877).
Once installed, on 16×H100, 8×H200, or 8×B200, you can run the model with tensor parallelism (expert parallelism has a slight bug we are fixing):
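The exact launch command and flags are in the recipes linked above. As a minimal, hedged sketch of running the model offline with vLLM's Python API (the prompt and sampling settings here are placeholders):

```python
# Minimal offline sketch; see the recipes for the recommended serve command and flags.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2-Exp",
    tensor_parallel_size=8,  # e.g., 8xH200 or 8xB200
)

sampling = SamplingParams(temperature=0.6, max_tokens=512)
outputs = llm.generate(["Explain DeepSeek Sparse Attention in one paragraph."], sampling)
print(outputs[0].outputs[0].text)
```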
To deploy at scale, we look forward to sharing our one-click Kubernetes deployment using `llm-d` later this week. This approach launches vLLM with PD disaggregation using NIXL, then for each P and D instance, efficiently routes requests to different data parallel ranks. Documentation will be available soon.
Once you start the engine, we recommend testing it with *long inputs or prompts expecting long outputs*. We also recommend comparing it with V3.1-Terminus, as it is continuously pre-trained on the same data mix.
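For example, assuming the server is running locally on port 8000, a long document can be sent through the OpenAI-compatible API (the file path and sampling settings below are placeholders):

```python
# Hedged sketch: send a long-context request to a locally running vLLM server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

long_document = open("long_document.txt").read()  # placeholder long-context input
resp = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-V3.2-Exp",
    messages=[{"role": "user", "content": f"Summarize the following document:\n\n{long_document}"}],
    max_tokens=2048,
)
print(resp.choices[0].message.content)
```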
We are still in the process of verifying vLLM's implementation against the official accuracy results. On a previous version of the model weights, we matched the expected GSM8K and GPQA-Diamond scores and showed that the model performs similarly to V3.1-Terminus.
### Implementation of Top-K Sparse Attention in vLLM
#### New Cache Entry and Quantization Scheme
The lightning indexer module caches K values specifically for indexing. This means that for each token, there is now another K cache used by the indexer. vLLM allocates separate buffers for the indexer K cache, keeping it apart from the MLA K cache.
Another interesting point is the handling of the FP8 KV cache, which this model supports. For MLA, each token's KV cache is 656 bytes, structured as:
* First 512 bytes: The "quantized NoPE" part, containing 512 `float8_e4m3` values.
* Next 16 bytes: Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on.
* Last 128 bytes: The "RoPE" part, containing 64 `bfloat16` values. This part is not quantized for accuracy.
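To make the layout concrete, here is a small PyTorch sketch (not vLLM's actual kernel code) that unpacks one such 656-byte entry, following the byte order described above:

```python
import torch

def unpack_mla_entry(entry: torch.Tensor):
    """Sketch: decode one 656-byte MLA KV-cache entry (a torch.uint8 tensor of shape (656,))."""
    assert entry.dtype == torch.uint8 and entry.numel() == 656
    nope_q = entry[:512].view(torch.float8_e4m3fn)   # 512 quantized NoPE values
    scales = entry[512:528].view(torch.float32)      # 4 scales, one per group of 128 values
    rope = entry[528:].view(torch.bfloat16)          # 64 unquantized RoPE values
    # Dequantize each group of 128 NoPE values with its own scale.
    nope = nope_q.to(torch.float32).view(4, 128) * scales.view(4, 1)
    return nope.view(512), rope
```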
The indexer key cache, however, is stored on a per-block basis. This is one of the reasons we only support block size 64 for this model; the other is that FlashMLA is tailored to it as well. Within a block, the first `block_size * head_dim` entries contain the quantized values, and the remaining entries contain the scaling factors.
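A rough sketch of how one such block can be split into its two regions (the indexer head dimension and the flat 1D buffer view here are assumptions for illustration; the real dtypes and strides in vLLM differ):

```python
import torch

BLOCK_SIZE = 64   # the only supported block size for this model
HEAD_DIM = 128    # indexer key head dimension, assumed here for illustration

def split_indexer_block(block: torch.Tensor):
    """Sketch: view one indexer key-cache block as a values region followed by a scales region."""
    n_vals = BLOCK_SIZE * HEAD_DIM
    values = block[:n_vals].view(BLOCK_SIZE, HEAD_DIM)  # row t = quantized indexer key of token t
    scales = block[n_vals:]                             # scaling factors for all tokens in the block
    return values, scales
```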
In the indexer, the cache for one token is not stored contiguously.
#### New Computation with Masking
Each new query token now passes through the indexer to compute the top 2048 tokens to attend to. A query for a token is a tensor of shape `(h, d)`, with `h` being the number of query heads, and `d` being the head dimension. The context of size `n` is a 2D tensor of shape `(n, d)`. The computed logits (relevance scores between the query and the context) are a tensor of shape `(n, h)`. Weighting the logits by the head weights of shape `(h,)`, we get a tensor of shape `(n,)`. We need to produce a `(2048,)` integer tensor of the indices of the top-2048 tokens, with `-1` filled for the rest if there are fewer than 2048 tokens.
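As a reference-only PyTorch sketch of this single-token selection (the function name is ours; the real DeepGEMM/fused kernels apply the same idea with different numerics and memory layouts):

```python
import torch

def select_topk_for_one_token(q, k_context, head_weights, topk=2048):
    """Sketch: q is (h, d), k_context is (n, d), head_weights is (h,)."""
    logits = k_context @ q.T            # (n, h): relevance of each context token per query head
    scores = logits @ head_weights      # (n,): head-weighted relevance per context token
    k = min(topk, scores.shape[0])
    top = torch.topk(scores, k).indices
    out = torch.full((topk,), -1, dtype=torch.long, device=scores.device)
    out[:k] = top                       # pad with -1 when there are fewer than `topk` tokens
    return out
```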
While it's straightforward to see how a single query token selects indices to attend to, the batching case is more complicated. Let's break it down.
For several query tokens (length `q`) from the same request (i.e., the prefill case), they are stored in a tensor of shape `(q, h, d)`. The context still has `n` tokens, so the context is still a 2D tensor of shape `(n, d)`. The logits are a tensor of shape `(q, n, h)`. Weighting the logits by the head weights, we get a tensor of shape `(q, n)`. We need to produce a `(q, 2048)` integer tensor of the indices of the top-2048 tokens. Due to causality, every query token only attends to the tokens before it. We need to mark the start context and the end context for each query token. We use `ks` to mark the start context, and `ke` to mark the end context. `ks` and `ke` are both `(q,)`-shaped integer tensors. In this case, `ks` will be all zeros, and `ke` will be `list(range(n - q, n, 1))`.
Finally, let's consider how to batch multiple requests. We have `b` requests, each request has `q1, q2, ..., qb` query tokens, and `n1, n2, ..., nb` context tokens. The query tokens will be batched into a tensor of shape `(q1 + q2 + ... + qb, h, d)`. The context will be batched into a tensor of shape `(n1 + n2 + ... + nb, d)`. The logits will be batched into a tensor of shape `(q1 + q2 + ... + qb, n1 + n2 + ... + nb, h)`. We need to produce a `(q1 + q2 + ... + qb, 2048)` integer tensor of the indices of the top-2048 tokens.
As before, we use `ks` to mark the start context and `ke` to mark the end context for each query token; both are `(q1 + q2 + ... + qb,)`-shaped integer tensors.
In this case, `ks` will be `[0] * q1 + [n1] * q2 + [n1 + n2] * q3 + ...`, i.e., each request's query tokens share the offset of that request's context within the packed context dimension. Here `*` means repeating the list. `ke` will be `list(range(n1 - q1, n1, 1)) + list(range(n2 - q2, n2, 1)) + ... + list(range(nb - qb, nb, 1))`, plus the offset of `ks`.
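A small Python sketch of this construction (the helper name is ours; it simply transcribes the formulas above):

```python
def build_ks_ke(q_lens, n_lens):
    """Sketch: build per-query-token context start (ks) and end (ke) markers for a batch.
    q_lens = [q1, ..., qb], n_lens = [n1, ..., nb]."""
    ks, ke = [], []
    ctx_offset = 0  # start of this request's context in the packed (n1 + ... + nb) dimension
    for q, n in zip(q_lens, n_lens):
        ks.extend([ctx_offset] * q)                           # all q query tokens share the same start
        ke.extend(ctx_offset + e for e in range(n - q, n))    # per-token end markers, as described above
        ctx_offset += n
    return ks, ke
```

For a single prefill request (`b = 1`), this reduces to the all-zeros `ks` and `list(range(n - q, n, 1))` `ke` described earlier.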
After computing the logits, we need to perform the `topk` operation. However, a clear challenge is that at high batch size with long context, the full logits tensor is materialized before running a row-wise `topk`. The vLLM team is working on a fused version inspired by FlashAttention, so that we can run an online `topk` without materializing the intermediate logits.
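For reference, here is a hedged PyTorch sketch of the unfused path just described, which makes the memory cost explicit by materializing the full score matrix before the row-wise `topk` (the masking convention and padding are illustrative, not the exact kernel semantics):

```python
import torch

def topk_indices_unfused(logits, head_weights, ks, ke, topk=2048):
    """Sketch: logits is (total_q, total_kv, h); head_weights is (h,);
    ks and ke are (total_q,) integer tensors marking each token's context window."""
    scores = logits @ head_weights                              # (total_q, total_kv), fully materialized
    kv_pos = torch.arange(scores.shape[1], device=scores.device)
    valid = (kv_pos >= ks[:, None]) & (kv_pos <= ke[:, None])   # per-row context-window mask
    scores = scores.masked_fill(~valid, float("-inf"))
    top = scores.topk(min(topk, scores.shape[1]), dim=-1).indices
    # Mark selections that fell outside a row's valid window as -1 (a full version
    # would also pad each row out to width `topk`).
    picked_valid = torch.gather(valid, 1, top)
    return top.masked_fill(~picked_valid, -1)
```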
#### Fusion Pass, More Kernels, and Blackwell Support
As we began optimizing performance, we started with some low-hanging fruit:
* Top-K can be expressed with a fused kernel for better performance. The TileLang kernel from the DeepSeek team serves as a great reference!
* We quantize the MLA latent and indexer key vectors as they are written to vLLM's page table. This turns out to be non-trivial because, as explained above, the quantization scheme is new and different.
We are also excited to announce out-of-the-box Blackwell support for this model. We strive to make the Blackwell platform a first-class citizen in model releases going forward, as its efficiency helps bring out the best performance!
### Ongoing Work
We have barely scratched the surface of optimizing DSA and related sparse attention in vLLM. In the coming weeks:
* We plan to expand the architectures supported beyond Hopper and Blackwell.
* We will bring in AMD support as well.
* We will continuously test large-scale wide EP serving and disaggregation.
* You will soon be able to run an end-to-end RL loop with this model.
* We will explore the "masked MHA mode for short sequence prefilling" from DeepSeek.
* In this release, we removed Hadamard transforms as we observed no effect on accuracy. We will investigate further!
### Acknowledgements
The following teams in the vLLM community worked on supporting this model:
* vLLM core: Chen Zhang, Yongye Zhu, Kaichao You, Simon Mo, Zhuohan Li
* Red Hat: Lucas Wilkinson, Matt Bonanni, Wentao Ye, Nicolo Lucchesi, Michael Goin, Robert Shaw, Tyler Michael Smith
* Meta: Lucia Fang, Xiaozhu Meng, Lu Fang
* NVIDIA: Ray Wang, Barry Kang, Daniel Campora, Julien Demouth, Siyuan Fu, Zeyu Wang, Pen Chun Li
As the vLLM team, we want to thank the DeepSeek team for open-sourcing this model, techniques, and kernels, as well as DeepSeek leadership for their trust and support in vLLM!