from abc import abstractmethod
from dataclasses import dataclass
-from typing import Any, Dict, Generic, List, Optional
+from typing import Any, Dict, Generic, List, Optional, Tuple

import torch
+from compressed_tensors.quantization import QuantizationStrategy

from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionLayer,
                                              AttentionMetadata,
                                              MLAAttentionImpl, T)
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               RowParallelLinear)
+                                               LinearBase, RowParallelLinear,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+    CompressedTensorsLinearMethod)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsW8A8Fp8)
+from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    scaled_dequantize, scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func

@@ -25,11 +37,11 @@ class MLACommonMetadata(AttentionMetadata):

class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
    """
-    Common class for implementing repeated parts
-
+    Common class for implementing repeated parts
+
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
-
+
    Deepseek's MLA attention works the following way:
    * Use a single latent vector to represent the entire KV cache.
    * The attention "simulates" a multi-head attention, while the compute is
@@ -46,7 +58,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
    * V: V head dim.
    * kv_c: latent/compressed KV
    * q_c: latent/compressed Q
-
+
    #
    # Outside the MLA attention backend
    #
@@ -55,21 +67,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
       kv_c_k_pe (B, Lkv+R).
    2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
       and kv_c are normalized.
-
+
    #
    # Inside the MLA attention backend
    #

    * if prefill:
-
-    3. The q_c is then projected up into the multi-head version.
-    * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
-      (B, N, P) and q_pe (B, N, R).
+
+    3. The q_c is then projected up into the multi-head version.
+    * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
+      (B, N, P) and q_pe (B, N, R).
    4. q_pe, k_pe are then passed through rotary embeddings.
    5. kv_c and k_pe are concatenated and inserted into the cache
-    6. The kv_c is then projected up into the multi-head version.
-    * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
-      dimensions for K and V, which is split into k_nope (B, N, P)
+    6. The kv_c is then projected up into the multi-head version.
+    * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
+      dimensions for K and V, which is split into k_nope (B, N, P)
       and v (B, N, V).
    7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
       q_nope, q_pe, k_nope, k_pe.
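To make the shapes in steps 3-7 concrete, here is a minimal shape-only sketch (rotary embedding and the cache insert are skipped); every size and weight tensor below is a hypothetical stand-in chosen for illustration, not something taken from this change:

import torch

B, N = 2, 16                 # tokens, attention heads (hypothetical)
Lq, Lkv = 1536, 512          # q / kv latent dims (hypothetical)
P, R, V = 128, 64, 128       # nope dim, rope dim, v head dim (hypothetical)

q_c = torch.randn(B, Lq)     # normalized latent q
kv_c = torch.randn(B, Lkv)   # normalized latent kv
k_pe = torch.randn(B, 1, R)  # rope part of k, shared across heads

# step 3: up-project q_c and split into q_nope (B, N, P) and q_pe (B, N, R)
W_UQ = torch.randn(Lq, N, P + R)
q_nope, q_pe = torch.einsum("bl,lnd->bnd", q_c, W_UQ).split([P, R], dim=-1)

# step 6: up-project kv_c and split into k_nope (B, N, P) and v (B, N, V)
W_UK_UV = torch.randn(Lkv, N, P + V)
k_nope, v = torch.einsum("bl,lnd->bnd", kv_c, W_UK_UV).split([P, V], dim=-1)

# step 7: assemble q and k, both (B, N, P+R)
q = torch.cat([q_nope, q_pe], dim=-1)
k = torch.cat([k_nope, k_pe.expand(-1, N, -1)], dim=-1)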
@@ -112,7 +124,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
    From @tsu-bin's calculation, we only want to use the absorption technique
    for decode. The prefill algorithm should still use the up-projected MHA
    for less flops and memory usage.
-
+
    """

    def __init__(
@@ -162,15 +174,32 @@ def __init__(

    def _v_up_proj_and_o_proj(self, x):
        if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
-            return self.o_proj_absorbed(
-                x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
+            if is_fp8(self.W_UV_O):
+                output_parallel = apply_fp8_linear_generic(
+                    x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
+                    self.requant_input_group_shape,
+                    self.requant_weight_group_shape)
+            else:
+                output_parallel = torch.matmul(x.flatten(start_dim=1),
+                                               self.W_UV_O)
+            if self.tp_size > 1:
+                output = tensor_model_parallel_all_reduce(output_parallel)
+            else:
+                output = output_parallel
+            return output
        else:
            x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
            return self.o_proj(x.reshape(-1,
                self.num_heads * self.v_head_dim))[0]

    def _q_proj_and_k_up_proj(self, x):
        if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
+            if is_fp8(self.W_Q_UK):
+                return apply_fp8_linear_generic(
+                    x, self.W_Q_UK, self.W_Q_UK_scales,
+                    self.requant_input_group_shape,
+                    self.requant_weight_group_shape).view(
+                        -1, self.num_heads, self.kv_lora_rank)
            return torch.matmul(x, self.W_Q_UK)\
                .view(-1, self.num_heads, self.kv_lora_rank)
        else:
@@ -179,8 +208,91 @@ def _q_proj_and_k_up_proj(self, x):
            return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
                .view(-1, self.num_heads, self.kv_lora_rank)

-    def process_weights_after_loading(self):
-        kv_b_proj_weight = self.kv_b_proj.weight.T
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+
+        def is_layer_fp8(layer: LinearBase) -> bool:
+            return isinstance(layer.quant_method, Fp8LinearMethod) or \
+                (isinstance(layer.quant_method, CompressedTensorsLinearMethod)
+                 and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
+
+        def quantization_scheme_supported(layer: LinearBase) -> bool:
+            return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
+                is_layer_fp8(layer)
+
+        # TODO(lucas) This is very gross, we need a wider-scale refactor of
+        # all the FP8 code with a more standard way of
+        # defining schemes/group-shapes; we should also potentially force
+        # quant_methods to support a decompress function
+        #
+        # returns input_group_shape, weight_group_shape
+        def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
+                Tuple[Tuple[int, int], Tuple[int, int]]:
+            if isinstance(layer.quant_method, Fp8LinearMethod):
+                if layer.quant_method.block_quant is not None:
+                    weight_block_size = \
+                        layer.quant_method.quant_config.weight_block_size
+                    # per-token-group (1, X), block-quantized (X, Y)
+                    return (1, weight_block_size[-1]), weight_block_size
+                else:
+                    return (-1, -1), (-1, -1)  # per-tensor, per-tensor
+            elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
+                    and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
+                # this is hacky, but we always assume that for
+                # CompressedTensorsW8A8Fp8 the input is dynamic per-token;
+                # we ignore the static per-tensor case since we are going to
+                # requantize later anyway
+                strategy = layer.scheme.strategy
+                if strategy == QuantizationStrategy.TENSOR:
+                    return (1, -1), (-1, -1)  # per-token, per-tensor
+                elif strategy == QuantizationStrategy.CHANNEL:
+                    return (1, -1), (-1, 1)  # per-token, per-channel
+                else:
+                    raise NotImplementedError(
+                        f"QuantizationStrategy.{strategy} is not supported for "
+                        "fp8 MLA, please run with VLLM_MLA_DISABLE=1")
+            else:
+                raise NotImplementedError(
+                    "Can't determine scale group shapes for "
+                    f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
+                )
+
+        def get_scales(layer: LinearBase) -> torch.Tensor:
+            if hasattr(layer, "weight_scale_inv"):
+                return layer.weight_scale_inv
+            return layer.weight_scale
+
+        def get_and_maybe_dequant_weights(layer: LinearBase):
+            if is_layer_fp8(layer):
+                if isinstance(layer.quant_method,
+                              CompressedTensorsLinearMethod) and \
+                        isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
+                    # NOTE(lucas): not sure why, but `CompressedTensorsW8A8Fp8`
+                    # seems to store weights as (input, output) instead of
+                    # (output, input) so we need to transpose
+                    weight = layer.weight.T  # standardize to (output, input)
+                else:
+                    weight = layer.weight
+                _, weight_scale_group_shape = \
+                    get_scale_group_shapes_for_fp8(layer)
+                scales = get_scales(layer)
+
+                return scaled_dequantize(weight, scales,
+                                         weight_scale_group_shape)
+            else:
+                return layer.weight
+
+        if not (quantization_scheme_supported(self.kv_b_proj) and
+                quantization_scheme_supported(self.q_proj) and
+                quantization_scheme_supported(self.o_proj)):
+            raise NotImplementedError(
+                "Only FP8 and UnquantizedLinearMethod are supported for MLA"
+                ", please run with VLLM_MLA_DISABLE=1")
+
+        weight_dtype = self.kv_b_proj.weight.dtype
+        assert self.o_proj.weight.dtype == weight_dtype
+        assert self.q_proj.weight.dtype == weight_dtype
+
+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
@@ -198,18 +310,35 @@ def process_weights_after_loading(self):
        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

-        q_proj = self.q_proj.weight.T\
+        q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
            .view(-1, self.num_heads, self.qk_head_dim)

        # can be W_Q or W_UQ depending q_lora_rank, the former if
        # q_lora_rank is None, the latter otherwise. From the Attention backend
        # perspective though we call these both W_Q and rely on the layer
        # to pass in the correct matrix
-        W_Q = q_proj[..., :self.qk_nope_head_dim]
-        self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
+        W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
+        self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
            .flatten(start_dim=1).contiguous()

+        # W_QR is small, so for simplicity we don't bother requantizing it
+        self.W_QR = self.W_QR.to(act_dtype)
+
        if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
+            requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
+            if is_fp8(weight_dtype) and requantization_enabled:
+                # This assumes it is wise to requantize using the same group
+                # shapes (i.e. strategy: per-tensor, per-channel, block, etc.)
+                # that the weights were originally quantized with
+                requant_input_group_shape, requant_weight_group_shape = \
+                    get_scale_group_shapes_for_fp8(self.q_proj)
+                assert (requant_input_group_shape, requant_weight_group_shape)\
+                    == get_scale_group_shapes_for_fp8(self.kv_b_proj)
+                assert (requant_input_group_shape, requant_weight_group_shape)\
+                    == get_scale_group_shapes_for_fp8(self.o_proj)
+                self.requant_input_group_shape = requant_input_group_shape
+                self.requant_weight_group_shape = requant_weight_group_shape
+
            #
            # Perform matrix-absorption following
            # https://github.com/flashinfer-ai/flashinfer/pull/551
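The absorption performed below works because the per-head logits are associative in the projections: folding W_UK into the q-side weight (giving W_Q_UK) yields the same result as up-projecting q and k separately, which is what lets decode attend directly against the cached latent kv_c. A small self-contained check of that identity, using hypothetical sizes and float64 so the comparison is tight; this sketch is illustrative and not code from this change:

import torch

B, N, Lq, Lkv, P = 2, 4, 64, 32, 16

def randn64(*shape):
    return torch.randn(*shape, dtype=torch.float64)

x = randn64(B, Lq)         # q_c (or hidden states when q_lora_rank is None)
W_Q = randn64(Lq, N, P)    # nope part of the q projection
W_UK = randn64(Lkv, N, P)  # k up-projection
kv_c = randn64(B, Lkv)     # rows stand in for cached latent kv vectors

# unabsorbed: up-project q and k separately, then take per-head dot products
q_nope = torch.einsum("bl,lnp->bnp", x, W_Q)
k_nope = torch.einsum("cl,lnp->cnp", kv_c, W_UK)
logits_ref = torch.einsum("bnp,cnp->bnc", q_nope, k_nope)

# absorbed: precompute W_Q_UK once, then attend directly in the latent space
W_Q_UK = torch.einsum("qnd,lnd->qnl", W_Q, W_UK).flatten(start_dim=1)
q_latent = torch.matmul(x, W_Q_UK).view(-1, N, Lkv)
logits_abs = torch.einsum("bnl,cl->bnc", q_latent, kv_c)

assert torch.allclose(logits_ref, logits_abs)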
@@ -223,25 +352,44 @@ def process_weights_after_loading(self):
            # latter otherwise
            # basically if q_lora_rank is none we are absorbing into q_proj
            # instead of UQ
-            self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
+            W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
                .flatten(start_dim=1).contiguous()

-            W_O = self.o_proj.weight\
+            if is_fp8(weight_dtype) and requantization_enabled:
+                W_Q_UK, W_Q_UK_scales = scaled_quantize(
+                    W_Q_UK,
+                    self.requant_weight_group_shape,
+                    quant_dtype=current_platform_fp8_dtype)
+                # For FP8 save the transpose so we can use
+                # `apply_w8a8_block_fp8_linear` directly
+                self.W_Q_UK = W_Q_UK.T.contiguous()
+                self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
+            else:
+                self.W_Q_UK = W_Q_UK.to(act_dtype)
+
+            W_O = get_and_maybe_dequant_weights(self.o_proj)\
                .view(-1, self.num_heads, self.v_head_dim)
-            self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
+            W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
                .flatten(start_dim=0, end_dim=1).contiguous()

-            tp_size = get_tensor_model_parallel_world_size()
-            self.o_proj_absorbed = RowParallelLinear(
-                self.W_UV_O.shape[0] * tp_size,
-                self.W_UV_O.shape[1],
-                bias=False,
-                # TODO(lucas) figure out how to properly forward quant_method
-                #quant_config=self.o_proj.quant_method,
-            )
-
-            self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
+            if is_fp8(weight_dtype) and requantization_enabled:
+                W_UV_O, W_UV_O_scales = scaled_quantize(
+                    W_UV_O,
+                    self.requant_weight_group_shape,
+                    quant_dtype=current_platform_fp8_dtype)
+                # For FP8 save the transpose so we can use
+                # `apply_w8a8_block_fp8_linear` directly
+                self.W_UV_O = W_UV_O.T.contiguous()
+                self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
+            else:
+                self.W_UV_O = W_UV_O.to(act_dtype)
+
+            self.tp_size = get_tensor_model_parallel_world_size()
        else:
+            if is_fp8(weight_dtype):
+                raise NotImplementedError(
+                    "Currently fp8 requires matrix absorption")
+
            self.W_UV = W_UV
            self.W_UK = W_UK
            self.W_Q = W_Q.flatten(start_dim=1)
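On the output side, the W_UV_O fold above relies on the same associativity: multiplying the latent attention output by the pre-folded W_UV_O matches up-projecting to v per head and then applying o_proj (with a tensor-parallel all-reduce on top when tp_size > 1). A hypothetical-size sketch of that equivalence, again illustrative rather than code from this change:

import torch

B, N, Lkv, V, H = 2, 4, 32, 16, 64

def randn64(*shape):
    return torch.randn(*shape, dtype=torch.float64)

x = randn64(B, N, Lkv)     # per-head attention output, still in latent space
W_UV = randn64(Lkv, N, V)  # v up-projection
W_O = randn64(H, N, V)     # o_proj weight viewed as (hidden, heads, v_head_dim)

# unabsorbed: up-project to v per head, then apply o_proj
v = torch.einsum("bnl,lnv->bnv", x, W_UV)
out_ref = torch.einsum("bnv,hnv->bh", v, W_O)

# absorbed: fold W_UV into W_O once at load time (mirrors W_UV_O above)
W_UV_O = torch.einsum("lnd,hnd->nlh", W_UV, W_O)\
    .flatten(start_dim=0, end_dim=1)
out_abs = x.flatten(start_dim=1) @ W_UV_O  # (B, N*Lkv) @ (N*Lkv, H)

assert torch.allclose(out_ref, out_abs)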