
Commit baeded2

LucasWilkinson, WoosukKwon, simon-mo, mgoin, and zhuohan123 authored
[Attention] Deepseek v3 MLA support with FP8 compute (#12601)
This PR implements Deepseek V3 support by performing matrix absorption with the FP8 weights.

---------

Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: simon-mo <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Zhuohan Li <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Alexander Matveev <[email protected]>
1 parent 3e1c76c commit baeded2

File tree

10 files changed: +580 -85 lines
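
The commit message describes matrix absorption of the FP8 weights. The idea, also spelled out in the docstring in mla/utils.py below, is that MLA's per-head keys and values are linear up-projections of a shared latent vector, so W_UK can be folded into the query projection and W_UV into the output projection, letting decode-time attention run directly on the cached latents. A minimal sketch of the idea (toy shapes and random weights, not the vLLM modules; the einsum subscripts match those used in the diff):

# Illustrative sketch of matrix absorption; names and shapes are toy values.
import torch

B, N = 4, 16               # tokens, attention heads
Lq, Lkv = 1536, 512        # q / kv latent (lora) ranks
P, V, H = 128, 128, 2048   # qk_nope head dim, v head dim, hidden size

W_UQ = torch.randn(Lq, N, P)   # up-projects q_c to per-head q_nope
W_UK = torch.randn(Lkv, N, P)  # up-projects kv_c to per-head k_nope
W_UV = torch.randn(Lkv, N, V)  # up-projects kv_c to per-head v
W_O = torch.randn(H, N, V)     # output projection, viewed per-head

# Absorb W_UK into the query path and W_UV into the output path so that
# decode attention operates in the compressed (latent) space.
W_Q_UK = torch.einsum("qnd,lnd->qnl", W_UQ, W_UK)   # (Lq, N, Lkv)
W_UV_O = torch.einsum("lnd,hnd->nlh", W_UV, W_O)    # (N, Lkv, H)

q_c = torch.randn(B, Lq)
q_latent = torch.einsum("bq,qnl->bnl", q_c, W_Q_UK)  # queries in latent space

# ... attention over the cached kv_c latents yields ctx_latent (B, N, Lkv) ...
ctx_latent = torch.randn(B, N, Lkv)
out = torch.einsum("bnl,nlh->bh", ctx_latent, W_UV_O)  # final hidden states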

vllm/attention/backends/mla/utils.py

Lines changed: 184 additions & 36 deletions
@@ -1,17 +1,29 @@
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, Generic, List, Optional
+from typing import Any, Dict, Generic, List, Optional, Tuple

 import torch
+from compressed_tensors.quantization import QuantizationStrategy

 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl, T)
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               RowParallelLinear)
+                                               LinearBase, RowParallelLinear,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+    CompressedTensorsLinearMethod)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsW8A8Fp8)
+from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    scaled_dequantize, scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.vllm_flash_attn import flash_attn_varlen_func

@@ -25,11 +37,11 @@ class MLACommonMetadata(AttentionMetadata):

 class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     """
-    Common class for implementing repeated parts
-
+    Common class for implementing repeated parts
+
     Main reference: DeepseekV2 paper, and FlashInfer Implementation
     (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
-
+
     Deepseek's MLA attention works the following way:
     * Use a single latent vector to represent the entire KV cache.
     * The attention "simulates" a multi-head attention, while the compute is
@@ -46,7 +58,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     * V: V head dim.
     * kv_c: latent/compressed KV
     * q_c: latent/compressed Q
-
+
     #
     # Outside the MLA attention backend
     #
@@ -55,21 +67,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
        kv_c_k_pe (B, Lkv+R).
     2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
        and kv_c are normalized.
-
+
     #
     # Inside the MLA attention backend
     #

     * if prefill:
-
-    3. The q_c is then projected up into the multi-head version.
-    * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
-      (B, N, P) and q_pe (B, N, R).
+
+    3. The q_c is then projected up into the multi-head version.
+    * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
+      (B, N, P) and q_pe (B, N, R).
     4. q_pe, k_pe are then passed through rotary embeddings.
     5. kv_c and k_pe are concatenated and inserted into the cache
-    6. The kv_c is then projected up into the multi-head version.
-    * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
-      dimensions for K and V, which is split into k_nope (B, N, P)
+    6. The kv_c is then projected up into the multi-head version.
+    * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
+      dimensions for K and V, which is split into k_nope (B, N, P)
       and v (B, N, V).
     7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
       q_nope, q_pe, k_nope, k_pe.
@@ -112,7 +124,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     From @tsu-bin's calculation, we only want to use the absorption technique
     for decode. The prefill algorithm should still use the up-projected MHA
     for less flops and memory usage.
-
+
     """

     def __init__(
@@ -162,15 +174,32 @@ def __init__(

     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
-            return self.o_proj_absorbed(
-                x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
+            if is_fp8(self.W_UV_O):
+                output_parallel = apply_fp8_linear_generic(
+                    x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
+                    self.reqaunt_input_group_shape,
+                    self.reqaunt_weight_group_shape)
+            else:
+                output_parallel = torch.matmul(x.flatten(start_dim=1),
+                                               self.W_UV_O)
+            if self.tp_size > 1:
+                output = tensor_model_parallel_all_reduce(output_parallel)
+            else:
+                output = output_parallel
+            return output
         else:
             x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
             return self.o_proj(x.reshape(-1,
                                          self.num_heads * self.v_head_dim))[0]

     def _q_proj_and_k_up_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
+            if is_fp8(self.W_Q_UK):
+                return apply_fp8_linear_generic(
+                    x, self.W_Q_UK, self.W_Q_UK_scales,
+                    self.reqaunt_input_group_shape,
+                    self.reqaunt_weight_group_shape).view(
+                        -1, self.num_heads, self.kv_lora_rank)
             return torch.matmul(x, self.W_Q_UK)\
                 .view(-1, self.num_heads, self.kv_lora_rank)
         else:
@@ -179,8 +208,91 @@ def _q_proj_and_k_up_proj(self, x):
             return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
                 .view(-1, self.num_heads, self.kv_lora_rank)

-    def process_weights_after_loading(self):
-        kv_b_proj_weight = self.kv_b_proj.weight.T
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+
+        def is_layer_fp8(layer: LinearBase) -> bool:
+            return isinstance(layer.quant_method, Fp8LinearMethod) or\
+                (isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
+                 and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
+
+        def quantization_scheme_supported(layer: LinearBase) -> bool:
+            return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
+                is_layer_fp8(layer)
+
+        # TODO(lucas) This is very gross, we need a more wide scale refactor of
+        # all the FP8 code with a more standard way of
+        # defining schemes/group-shapes, we should also potentially force
+        # quant_methods to support a decompress function
+        #
+        # returns input_group_shape, weight_group_shape
+        def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
+                Tuple[Tuple[int, int], Tuple[int, int]]:
+            if isinstance(layer.quant_method, Fp8LinearMethod):
+                if layer.quant_method.block_quant is not None:
+                    weight_block_size = \
+                        layer.quant_method.quant_config.weight_block_size
+                    # per-token-group (1, X), block-quantized (X, Y)
+                    return (1, weight_block_size[-1]), weight_block_size
+                else:
+                    return (-1, -1), (-1, -1)  # per-tensor, per-tensor
+            elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
+                    and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
+                # this is hacky but we always assume the for
+                # CompressedTensorsW8A8Fp8 the input is dynamic per-token
+                # we ignore if it is static-per-tensor since we are going to
+                # requantize after later anyways
+                strategy = layer.scheme.strategy
+                if strategy == QuantizationStrategy.TENSOR:
+                    return (1, -1), (-1, -1)  # per-token, per-tensor
+                elif strategy == QuantizationStrategy.CHANNEL:
+                    return (1, -1), (-1, 1)  # per-token, per-channel
+                else:
+                    raise NotImplementedError(
+                        f"QuantizationStrategy.{strategy} is not supported for "
+                        "fp8 MLA, please run with VLLM_MLA_DISABLE=1")
+            else:
+                raise NotImplementedError(
+                    "Can't determine scale group shapes for "
+                    f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
+                )
+
+        def get_scales(layer: LinearBase) -> torch.Tensor:
+            if hasattr(layer, "weight_scale_inv"):
+                return layer.weight_scale_inv
+            return layer.weight_scale
+
+        def get_and_maybe_dequant_weights(layer: LinearBase):
+            if is_layer_fp8(layer):
+                if isinstance(layer.quant_method, \
+                        CompressedTensorsLinearMethod) and \
+                        isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
+                    # NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
+                    # seems to store weights as (input, output) instead of
+                    # (output, input) so we need to transpose
+                    weight = layer.weight.T  # standardize to (output, input)
+                else:
+                    weight = layer.weight
+                _, weight_scale_group_shape = \
+                    get_scale_group_shapes_for_fp8(layer)
+                scales = get_scales(layer)
+
+                return scaled_dequantize(weight, scales,
+                                         weight_scale_group_shape)
+            else:
+                return layer.weight
+
+        if not (quantization_scheme_supported(self.kv_b_proj) and\
+                quantization_scheme_supported(self.q_proj) and\
+                quantization_scheme_supported(self.o_proj)):
+            raise NotImplementedError(
+                "Only FP8 and UnquantizedLinearMethod are supported for MLA"
+                ", please run with VLLM_MLA_DISABLE=1")
+
+        weight_dtype = self.kv_b_proj.weight.dtype
+        assert self.o_proj.weight.dtype == weight_dtype
+        assert self.q_proj.weight.dtype == weight_dtype
+
+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
@@ -198,18 +310,35 @@ def process_weights_after_loading(self):
         W_UK, W_UV = kv_b_proj_weight.split(
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

-        q_proj = self.q_proj.weight.T\
+        q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
                 .view(-1, self.num_heads, self.qk_head_dim)

         # can be W_Q or W_UQ depending q_lora_rank, the former if
         # q_lora_rank is None, the latter otherwise. From the Attention backend
         # perspective though we call these both W_Q and rely on the layer
         # to pass in the correct matrix
-        W_Q = q_proj[..., :self.qk_nope_head_dim]
-        self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
+        W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
+        self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
            .flatten(start_dim=1).contiguous()

+        # W_QR is small so for simplicity we dont bother requantizing it
+        self.W_QR = self.W_QR.to(act_dtype)
+
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
+            requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
+            if is_fp8(weight_dtype) and requantization_enabled:
+                # This assumes it wise to requantize using the same group shapes
+                # (i.e. strategy, per-tensor, per-channel, block etc.) that the
+                # weights were originally quantized
+                requant_input_group_shape, requant_weight_group_shape = \
+                    get_scale_group_shapes_for_fp8(self.q_proj)
+                assert (requant_input_group_shape, requant_weight_group_shape)\
+                    == get_scale_group_shapes_for_fp8(self.kv_b_proj)
+                assert (requant_input_group_shape, requant_weight_group_shape)\
+                    == get_scale_group_shapes_for_fp8(self.o_proj)
+                self.reqaunt_input_group_shape = requant_input_group_shape
+                self.reqaunt_weight_group_shape = requant_weight_group_shape
+
             #
             # Perform matrix-absorption following
             #     https://github.com/flashinfer-ai/flashinfer/pull/551
@@ -223,25 +352,44 @@ def process_weights_after_loading(self):
             # latter otherwise
             # basically if q_lora_rank is none we are absorbing into q_proj
             # instead of UQ
-            self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
+            W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
                 .flatten(start_dim=1).contiguous()

-            W_O = self.o_proj.weight\
+            if is_fp8(weight_dtype) and requantization_enabled:
+                W_Q_UK, W_Q_UK_scales = scaled_quantize(
+                    W_Q_UK,
+                    self.reqaunt_weight_group_shape,
+                    quant_dtype=current_platform_fp8_dtype)
+                # For FP8 save the transpose so we can use
+                # `apply_w8a8_block_fp8_linear` directly
+                self.W_Q_UK = W_Q_UK.T.contiguous()
+                self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
+            else:
+                self.W_Q_UK = W_Q_UK.to(act_dtype)
+
+            W_O = get_and_maybe_dequant_weights(self.o_proj)\
                 .view(-1, self.num_heads, self.v_head_dim)
-            self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
+            W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
                 .flatten(start_dim=0, end_dim=1).contiguous()

-            tp_size = get_tensor_model_parallel_world_size()
-            self.o_proj_absorbed = RowParallelLinear(
-                self.W_UV_O.shape[0] * tp_size,
-                self.W_UV_O.shape[1],
-                bias=False,
-                # TODO(lucas) figure out how to properly forward quant_method
-                #quant_config=self.o_proj.quant_method,
-            )
-
-            self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
+            if is_fp8(weight_dtype) and requantization_enabled:
+                W_UV_O, W_UV_O_scales = scaled_quantize(
+                    W_UV_O,
+                    self.reqaunt_weight_group_shape,
+                    quant_dtype=current_platform_fp8_dtype)
+                # For FP8 save the transpose so we can use
+                # `apply_w8a8_block_fp8_linear` directly
+                self.W_UV_O = W_UV_O.T.contiguous()
+                self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
+            else:
+                self.W_UV_O = W_UV_O.to(act_dtype)
+
+            self.tp_size = get_tensor_model_parallel_world_size()
         else:
+            if is_fp8(weight_dtype):
+                raise NotImplementedError(
+                    "Currently fp8 requires matrix absorption")
+
             self.W_UV = W_UV
             self.W_UK = W_UK
             self.W_Q = W_Q.flatten(start_dim=1)
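
The FP8 path in process_weights_after_loading above follows a dequantize → absorb → requantize pattern: the quantized kv_b_proj/q_proj/o_proj weights are dequantized, the fused W_Q_UK and W_UV_O matrices are formed, and the results are requantized using the same scale group shapes the checkpoint used. A self-contained sketch of that pattern in plain torch, using per-output-channel scales for illustration (the real code also supports per-tensor and block group shapes; names and shapes here are toy values, not the vLLM helpers):

# Hedged sketch of dequantize -> absorb -> requantize with per-channel FP8 scales.
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

def quant_per_channel(w: torch.Tensor):
    # one scale per output row of a 2D weight
    scales = w.abs().amax(dim=1, keepdim=True) / FP8_MAX
    w_fp8 = (w / scales).clamp(-FP8_MAX, FP8_MAX).to(FP8)
    return w_fp8, scales.squeeze(1)

def dequant_per_channel(w_fp8: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    return w_fp8.to(torch.float32) * scales[:, None]

# Pretend these are loaded FP8 checkpoint weights for W_UQ and W_UK (2D layout).
w_uq_fp8, s_uq = quant_per_channel(torch.randn(1536, 16 * 128))
w_uk_fp8, s_uk = quant_per_channel(torch.randn(512, 16 * 128))

# 1) dequantize, 2) absorb into a fused matrix, 3) requantize for FP8 GEMM.
W_UQ = dequant_per_channel(w_uq_fp8, s_uq).view(1536, 16, 128)
W_UK = dequant_per_channel(w_uk_fp8, s_uk).view(512, 16, 128)
W_Q_UK = torch.einsum("qnd,lnd->qnl", W_UQ, W_UK).flatten(start_dim=1)
w_q_uk_fp8, s_q_uk = quant_per_channel(W_Q_UK)  # stored and used at decode time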

vllm/attention/backends/triton_mla.py

Lines changed: 7 additions & 11 deletions
@@ -57,14 +57,12 @@ def get_state_cls() -> Type["TritonMLAState"]:

     @staticmethod
     def get_kv_cache_shape(
-        num_blocks: int,
-        block_size: int,
-        num_kv_heads: int,  # assumed to be 1 for MLA
-        kv_lora_rank: int,  # passed via head_size
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,  # assumed to be 1 for MLA
+        head_size: int,
     ) -> Tuple[int, ...]:
-        # TODO(lucas): remove hardcoding k_pe size as 1/8th of kv_lora_rank
-        k_pe_size = kv_lora_rank // 8
-        return (num_blocks, block_size, kv_lora_rank + k_pe_size)
+        return (num_blocks, block_size, head_size)

     @staticmethod
     def swap_blocks(
@@ -83,7 +81,7 @@ def copy_blocks(

     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [512]
+        return [576]


 class TritonMLAState(AttentionState):
@@ -624,8 +622,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                 self.multimodal_placeholder_maps.items()
             }

-        num_kv_splits = 8
-
         return TritonMLAMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
@@ -645,7 +641,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
-            num_kv_splits=num_kv_splits,
+            num_kv_splits=4,  # TODO(lucas) add heuristic
             head_dim=self.runner.model_config.get_head_size(),
         )

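
The supported head size moves from 512 to 576 because the MLA KV cache stores, per token, the compressed KV latent plus the rotary key dims rather than the latent alone with a hardcoded 1/8 rope fraction. With the published DeepSeek-V2/V3 dimensions this works out as follows (illustrative numbers, shown only to motivate the new return value):

# Why get_supported_head_sizes() now returns [576]; values are DeepSeek config
# numbers used for illustration, not read from vLLM at runtime.
kv_lora_rank = 512      # compressed KV latent dim (Lkv)
qk_rope_head_dim = 64   # rotary positional key dim (R)

head_size = kv_lora_rank + qk_rope_head_dim
assert head_size == 576

num_blocks, block_size = 1024, 16
kv_cache_shape = (num_blocks, block_size, head_size)  # mirrors get_kv_cache_shape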

vllm/attention/layer.py

Lines changed: 2 additions & 2 deletions
@@ -200,9 +200,9 @@ def extra_repr(self) -> str:
         s += f", backend={self.impl.__class__.__name__}"
         return s

-    def process_weights_after_loading(self):
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
         if hasattr(self.impl, "process_weights_after_loading"):
-            self.impl.process_weights_after_loading()
+            self.impl.process_weights_after_loading(act_dtype)


 class MultiHeadAttention(nn.Module):
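
Because of the signature change, whatever finalizes the model after weight loading must now pass the activation dtype down to the attention implementations so MLA can decide what to dequantize or requantize. A hedged sketch of how a caller might do that (the actual call site in vLLM's loader may differ; the import path and helper name here are assumptions for illustration):

# Hypothetical helper, not the vLLM loader code.
import torch
from torch import nn

from vllm.attention.layer import Attention  # assumed location of the class above

def finalize_attention_weights(model: nn.Module, act_dtype: torch.dtype) -> None:
    # Walk the module tree and hand each attention layer the activation dtype.
    for module in model.modules():
        if isinstance(module, Attention):
            module.process_weights_after_loading(act_dtype)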
