@@ -1,17 +1,29 @@
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, Generic, List, Optional
+from typing import Any, Dict, Generic, List, Optional, Tuple

 import torch
+from compressed_tensors.quantization import QuantizationStrategy

 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl, T)
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               RowParallelLinear)
+                                               LinearBase, RowParallelLinear,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+    CompressedTensorsLinearMethod)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsW8A8Fp8)
+from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    scaled_dequantize, scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.vllm_flash_attn import flash_attn_varlen_func

@@ -25,11 +37,11 @@ class MLACommonMetadata(AttentionMetadata):

 class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     """
-    Common class for implementing repeated parts 
-    
+    Common class for implementing repeated parts
+
     Main reference: DeepseekV2 paper, and FlashInfer Implementation
     (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
-    
+
     Deepseek's MLA attention works the following way:
     * Use a single latent vector to represent the entire KV cache.
     * The attention "simulates" a multi-head attention, while the compute is
@@ -46,7 +58,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     * V: V head dim.
     * kv_c: latent/compressed KV
     * q_c: latent/compressed Q
-    
+
     #
     # Outside the MLA attention backend
     #
@@ -55,21 +67,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
        kv_c_k_pe (B, Lkv+R).
     2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
        and kv_c are normalized.
-    
+
     #
     # Inside the MLA attention backend
     #

     * if prefill:
-    
-    3. The q_c is then projected up into the multi-head version.
-    * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
-      (B, N, P) and q_pe (B, N, R).
+
+    3. The q_c is then projected up into the multi-head version.
+    * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
+      (B, N, P) and q_pe (B, N, R).
     4. q_pe, k_pe are then passed through rotary embeddings.
     5. kv_c and k_pe are concatenated and inserted into the cache
-    6. The kv_c is then projected up into the multi-head version.
-    * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
-      dimensions for K and V, which is split into k_nope (B, N, P)
+    6. The kv_c is then projected up into the multi-head version.
+    * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
+      dimensions for K and V, which is split into k_nope (B, N, P)
       and v (B, N, V).
     7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
        q_nope, q_pe, k_nope, k_pe.
@@ -112,7 +124,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     From @tsu-bin's calculation, we only want to use the absorption technique
     for decode. The prefill algorithm should still use the up-projected MHA
     for less flops and memory usage.
-    
+
     """

     def __init__(
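To make the absorption trick described above concrete, here is a small self-contained sketch (toy sizes and random tensors; the dimension names follow the glossary in the docstring, and nothing below is part of the patch) checking that folding W_UK into the query side and W_UV into the output projection, via the same einsum contractions this file uses, reproduces the un-absorbed result:

import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)  # keep the equivalence check exact-ish
# Toy sizes following the docstring's glossary (values are arbitrary).
B, N, P, V, Lkv, H = 2, 8, 16, 16, 32, 64

q_nope = torch.randn(B, N, P)   # per-head "nope" query
kv_c = torch.randn(B, Lkv)      # compressed KV latent for one cached token
W_UK = torch.randn(Lkv, N, P)   # latent -> per-head K up-projection
W_UV = torch.randn(Lkv, N, V)   # latent -> per-head V up-projection
W_O = torch.randn(H, N, V)      # per-head output projection (o_proj weight)

# Query/K side.
# Un-absorbed: up-project the latent into per-head keys, then dot with q.
k_nope = torch.einsum("bl,lnp->bnp", kv_c, W_UK)
scores_ref = torch.einsum("bnp,bnp->bn", q_nope, k_nope)
# Absorbed: fold W_UK into the query instead (the "bnp,lnp->bnl" einsum used
# by _q_proj_and_k_up_proj), then dot with the latent directly.
q_latent = torch.einsum("bnp,lnp->bnl", q_nope, W_UK)
scores_abs = torch.einsum("bnl,bl->bn", q_latent, kv_c)
assert torch.allclose(scores_ref, scores_abs)

# V/output side.
attn_latent = torch.randn(B, N, Lkv)  # stand-in for the latent attention output
# Un-absorbed: up-project to per-head V, then apply the output projection.
v = torch.einsum("bnl,lnv->bnv", attn_latent, W_UV)
out_ref = torch.einsum("bnv,hnv->bh", v, W_O)
# Absorbed: precompute W_UV_O with the "lnd,hnd->nlh" einsum once and apply it
# as a single matmul, as _v_up_proj_and_o_proj does.
W_UV_O = torch.einsum("lnd,hnd->nlh", W_UV, W_O).flatten(0, 1)  # (N*Lkv, H)
out_abs = attn_latent.reshape(B, N * Lkv) @ W_UV_O
assert torch.allclose(out_ref, out_abs)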
@@ -162,15 +174,32 @@ def __init__(

     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
-            return self.o_proj_absorbed(
-                x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
+            if is_fp8(self.W_UV_O):
+                output_parallel = apply_fp8_linear_generic(
+                    x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
+                    self.reqaunt_input_group_shape,
+                    self.reqaunt_weight_group_shape)
+            else:
+                output_parallel = torch.matmul(x.flatten(start_dim=1),
+                                               self.W_UV_O)
+            if self.tp_size > 1:
+                output = tensor_model_parallel_all_reduce(output_parallel)
+            else:
+                output = output_parallel
+            return output
         else:
             x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
             return self.o_proj(x.reshape(-1,
                                          self.num_heads * self.v_head_dim))[0]

     def _q_proj_and_k_up_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
+            if is_fp8(self.W_Q_UK):
+                return apply_fp8_linear_generic(
+                    x, self.W_Q_UK, self.W_Q_UK_scales,
+                    self.reqaunt_input_group_shape,
+                    self.reqaunt_weight_group_shape).view(
+                        -1, self.num_heads, self.kv_lora_rank)
             return torch.matmul(x, self.W_Q_UK)\
                 .view(-1, self.num_heads, self.kv_lora_rank)
         else:
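The new decode path replaces the temporary o_proj_absorbed RowParallelLinear with an explicit matmul followed by tensor_model_parallel_all_reduce. Below is a toy single-process sketch (made-up shapes, with a plain Python sum standing in for the all-reduce; this is not vLLM code) of why summing per-rank partial products over the head-sharded absorbed matrix matches the unsharded result:

import torch

torch.set_default_dtype(torch.float64)
# Toy shapes: 2 "TP ranks", each owning half of the heads.
B, N, Lkv, H, tp = 2, 8, 32, 64, 2
x = torch.randn(B, N, Lkv)        # latent attention output, all heads
W_UV_O = torch.randn(N * Lkv, H)  # absorbed V/O projection, all heads

# Full (single-rank) result.
ref = x.reshape(B, N * Lkv) @ W_UV_O

# Row-parallel style: each rank holds the rows of W_UV_O for its own heads and
# multiplies only its local heads' activations; combining the partial outputs
# with a sum is exactly what tensor_model_parallel_all_reduce does in the
# patched _v_up_proj_and_o_proj.
partials = []
for r in range(tp):
    heads = slice(r * N // tp, (r + 1) * N // tp)
    x_local = x[:, heads, :].reshape(B, -1)
    w_local = W_UV_O.reshape(N, Lkv, H)[heads].reshape(-1, H)
    partials.append(x_local @ w_local)
assert torch.allclose(sum(partials), ref)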
@@ -179,8 +208,91 @@ def _q_proj_and_k_up_proj(self, x):
             return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
                 .view(-1, self.num_heads, self.kv_lora_rank)

-    def process_weights_after_loading(self):
-        kv_b_proj_weight = self.kv_b_proj.weight.T
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+
+        def is_layer_fp8(layer: LinearBase) -> bool:
+            return isinstance(layer.quant_method, Fp8LinearMethod) or \
+                (isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
+                 and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
+
+        def quantization_scheme_supported(layer: LinearBase) -> bool:
+            return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
+                is_layer_fp8(layer)
+
+        # TODO(lucas) This is very gross, we need a wider-scale refactor of
+        # all the FP8 code with a more standard way of
+        # defining schemes/group-shapes; we should also potentially force
+        # quant_methods to support a decompress function
+        #
+        # returns input_group_shape, weight_group_shape
+        def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
+                Tuple[Tuple[int, int], Tuple[int, int]]:
+            if isinstance(layer.quant_method, Fp8LinearMethod):
+                if layer.quant_method.block_quant is not None:
+                    weight_block_size = \
+                        layer.quant_method.quant_config.weight_block_size
+                    # per-token-group (1, X), block-quantized (X, Y)
+                    return (1, weight_block_size[-1]), weight_block_size
+                else:
+                    return (-1, -1), (-1, -1)  # per-tensor, per-tensor
+            elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
+                    and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
+                # this is hacky but we always assume that for
+                # CompressedTensorsW8A8Fp8 the input is dynamic per-token;
+                # we ignore if it is static per-tensor since we are going to
+                # requantize later anyways
+                strategy = layer.scheme.strategy
+                if strategy == QuantizationStrategy.TENSOR:
+                    return (1, -1), (-1, -1)  # per-token, per-tensor
+                elif strategy == QuantizationStrategy.CHANNEL:
+                    return (1, -1), (-1, 1)  # per-token, per-channel
+                else:
+                    raise NotImplementedError(
+                        f"QuantizationStrategy.{strategy} is not supported for "
+                        "fp8 MLA, please run with VLLM_MLA_DISABLE=1")
+            else:
+                raise NotImplementedError(
+                    "Can't determine scale group shapes for "
+                    f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
+                )
+
+        def get_scales(layer: LinearBase) -> torch.Tensor:
+            if hasattr(layer, "weight_scale_inv"):
+                return layer.weight_scale_inv
+            return layer.weight_scale
+
+        def get_and_maybe_dequant_weights(layer: LinearBase):
+            if is_layer_fp8(layer):
+                if isinstance(layer.quant_method, \
+                        CompressedTensorsLinearMethod) and \
+                        isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
+                    # NOTE(lucas): not sure why but `CompressedTensorsW8A8Fp8`
+                    # seems to store weights as (input, output) instead of
+                    # (output, input) so we need to transpose
+                    weight = layer.weight.T  # standardize to (output, input)
+                else:
+                    weight = layer.weight
+                _, weight_scale_group_shape = \
+                    get_scale_group_shapes_for_fp8(layer)
+                scales = get_scales(layer)
+
+                return scaled_dequantize(weight, scales,
+                                         weight_scale_group_shape)
+            else:
+                return layer.weight
+
+        if not (quantization_scheme_supported(self.kv_b_proj) and \
+                quantization_scheme_supported(self.q_proj) and \
+                quantization_scheme_supported(self.o_proj)):
+            raise NotImplementedError(
+                "Only FP8 and UnquantizedLinearMethod are supported for MLA"
+                ", please run with VLLM_MLA_DISABLE=1")
+
+        weight_dtype = self.kv_b_proj.weight.dtype
+        assert self.o_proj.weight.dtype == weight_dtype
+        assert self.q_proj.weight.dtype == weight_dtype
+
+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
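The group-shape tuples returned by get_scale_group_shapes_for_fp8 above follow a simple convention: -1 means the group spans the whole dimension, so (-1, -1) is per-tensor, (1, -1) is per-token/per-row, (-1, 1) is per-channel/per-column, and (128, 128) is a 128x128 block. The tiny quantizer below only illustrates that convention (it is an assumption for illustration, not vLLM's scaled_quantize, and it needs a PyTorch build with float8 dtypes):

import torch

def quantize_with_group_shape(w: torch.Tensor, group_shape,
                              dtype=torch.float8_e4m3fn):
    # -1 in a group-shape entry means "one group spans that whole dimension".
    rows, cols = w.shape
    gr = rows if group_shape[0] == -1 else group_shape[0]
    gc = cols if group_shape[1] == -1 else group_shape[1]
    assert rows % gr == 0 and cols % gc == 0
    # View as (row-groups, gr, col-groups, gc) and compute one scale per group.
    g = w.reshape(rows // gr, gr, cols // gc, gc)
    amax = g.abs().amax(dim=(1, 3), keepdim=True)
    scale = amax / torch.finfo(dtype).max
    q = (g / scale).to(dtype)
    return q.reshape(rows, cols), scale[:, 0, :, 0]

w = torch.randn(256, 512)
q_t, s_t = quantize_with_group_shape(w, (-1, -1))    # per-tensor: 1x1 scales
q_c, s_c = quantize_with_group_shape(w, (-1, 1))     # per-column: 1x512 scales
q_b, s_b = quantize_with_group_shape(w, (128, 128))  # block: 2x4 scales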
@@ -198,18 +310,35 @@ def process_weights_after_loading(self):
         W_UK, W_UV = kv_b_proj_weight.split(
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

-        q_proj = self.q_proj.weight.T\
+        q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
                 .view(-1, self.num_heads, self.qk_head_dim)

         # can be W_Q or W_UQ depending on q_lora_rank, the former if
         # q_lora_rank is None, the latter otherwise. From the Attention backend
         # perspective though we call these both W_Q and rely on the layer
         # to pass in the correct matrix
-        W_Q = q_proj[..., :self.qk_nope_head_dim]
-        self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
+        W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
+        self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
            .flatten(start_dim=1).contiguous()

+        # W_QR is small so for simplicity we don't bother requantizing it
+        self.W_QR = self.W_QR.to(act_dtype)
+
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
+            requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
+            if is_fp8(weight_dtype) and requantization_enabled:
+                # This assumes it is wise to requantize using the same group
+                # shapes (i.e. strategy: per-tensor, per-channel, block, etc.)
+                # that the weights were originally quantized with
+                requant_input_group_shape, requant_weight_group_shape = \
+                    get_scale_group_shapes_for_fp8(self.q_proj)
+                assert (requant_input_group_shape, requant_weight_group_shape)\
+                    == get_scale_group_shapes_for_fp8(self.kv_b_proj)
+                assert (requant_input_group_shape, requant_weight_group_shape)\
+                    == get_scale_group_shapes_for_fp8(self.o_proj)
+                self.reqaunt_input_group_shape = requant_input_group_shape
+                self.reqaunt_weight_group_shape = requant_weight_group_shape
+
             #
             # Perform matrix-absorption following
             # https://github.com/flashinfer-ai/flashinfer/pull/551
@@ -223,25 +352,44 @@ def process_weights_after_loading(self):
             # latter otherwise
             # basically if q_lora_rank is None we are absorbing into q_proj
             # instead of UQ
-            self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
+            W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
                 .flatten(start_dim=1).contiguous()

-            W_O = self.o_proj.weight\
+            if is_fp8(weight_dtype) and requantization_enabled:
+                W_Q_UK, W_Q_UK_scales = scaled_quantize(
+                    W_Q_UK,
+                    self.reqaunt_weight_group_shape,
+                    quant_dtype=current_platform_fp8_dtype)
+                # For FP8 save the transpose so we can use
+                # `apply_w8a8_block_fp8_linear` directly
+                self.W_Q_UK = W_Q_UK.T.contiguous()
+                self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
+            else:
+                self.W_Q_UK = W_Q_UK.to(act_dtype)
+
+            W_O = get_and_maybe_dequant_weights(self.o_proj)\
                 .view(-1, self.num_heads, self.v_head_dim)
-            self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
+            W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
                 .flatten(start_dim=0, end_dim=1).contiguous()

-            tp_size = get_tensor_model_parallel_world_size()
-            self.o_proj_absorbed = RowParallelLinear(
-                self.W_UV_O.shape[0] * tp_size,
-                self.W_UV_O.shape[1],
-                bias=False,
-                # TODO(lucas) figure out how to properly forward quant_method
-                #quant_config=self.o_proj.quant_method,
-            )
-
-            self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
+            if is_fp8(weight_dtype) and requantization_enabled:
+                W_UV_O, W_UV_O_scales = scaled_quantize(
+                    W_UV_O,
+                    self.reqaunt_weight_group_shape,
+                    quant_dtype=current_platform_fp8_dtype)
+                # For FP8 save the transpose so we can use
+                # `apply_w8a8_block_fp8_linear` directly
+                self.W_UV_O = W_UV_O.T.contiguous()
+                self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
+            else:
+                self.W_UV_O = W_UV_O.to(act_dtype)
+
+            self.tp_size = get_tensor_model_parallel_world_size()
         else:
+            if is_fp8(weight_dtype):
+                raise NotImplementedError(
+                    "Currently fp8 requires matrix absorption")
+
             self.W_UV = W_UV
             self.W_UK = W_UK
             self.W_Q = W_Q.flatten(start_dim=1)
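Putting the pieces of process_weights_after_loading together, the fp8-with-requantization path boils down to: dequantize the checkpoint weights, fuse them with the absorption einsum, then quantize the fused matrix once using the original group shape. The toy end-to-end sketch below uses a stand-in per-tensor quantizer (illustrative only; the real code uses scaled_dequantize/scaled_quantize and per-channel or block group shapes, and the toy shapes are made up):

import torch

torch.manual_seed(0)
FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

def quant_per_tensor(w):
    # toy per-tensor fp8 quantization, standing in for scaled_quantize
    scale = w.abs().max() / FP8_MAX
    return (w / scale).to(FP8), scale

def dequant(w_q, scale):
    # toy dequantization, standing in for scaled_dequantize
    return w_q.float() * scale

# Pretend q_proj and kv_b_proj arrived as fp8 checkpoint weights (toy sizes).
Lq, N, P, Lkv = 64, 8, 16, 32
W_Q_q, s_q = quant_per_tensor(torch.randn(Lq, N, P))
W_UK_q, s_uk = quant_per_tensor(torch.randn(Lkv, N, P))

# 1) dequantize to high precision, 2) absorb with the einsum from the patch,
# 3) requantize the fused matrix once with the original group shape
#    (per-tensor in this toy example).
W_Q = dequant(W_Q_q, s_q)
W_UK = dequant(W_UK_q, s_uk)
W_Q_UK = torch.einsum("qnd,lnd->qnl", W_Q, W_UK).flatten(start_dim=1)
W_Q_UK_q, s_fused = quant_per_tensor(W_Q_UK)  # the patch stores the transpose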