5
5
6
6
from vllm .model_executor .layers .linear import LinearBase , LinearMethodBase
7
7
from vllm .model_executor .layers .quantization .base_config import ( # noqa: E501
8
- QuantizationConfig )
8
+ QuantizationConfig , QuantizeMethodBase )
9
9
from vllm .model_executor .layers .quantization .compressed_tensors .schemes import (
10
10
W4A16SPARSE24_SUPPORTED_BITS , WNA16_SUPPORTED_BITS ,
11
11
CompressedTensorsScheme , CompressedTensorsUnquantized ,
15
15
CompressionFormat , QuantizationArgs , QuantizationStrategy ,
16
16
QuantizationType , find_matched_target , is_activation_quantization_format ,
17
17
should_ignore_layer )
18
+ from vllm .model_executor .layers .quantization .kv_cache import BaseKVCacheMethod
18
19
from vllm .platforms import current_platform
19
20
20
21
21
22
class CompressedTensorsConfig (QuantizationConfig ):
22
23
23
def __init__(self,
             target_scheme_map: Dict[str, Any],
             ignore: List[str],
             quant_format: str,
             kv_cache_scheme: Optional[Dict[str, Any]] = None):
    """Store the parsed compressed-tensors checkpoint configuration.

    :param target_scheme_map: map from [target -> scheme] describing how
        each matched layer target is quantized
    :param ignore: layer names/targets excluded from quantization
    :param quant_format: compression format of the checkpoint
    :param kv_cache_scheme: optional kv-cache quantization scheme; ``None``
        means the kv cache is not quantized
    """
    # Map from [target -> scheme]
    self.target_scheme_map = target_scheme_map
    self.kv_cache_scheme = kv_cache_scheme
    self.quant_format = quant_format
    self.ignore = ignore
30
35
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
    """Return a linear-layer quantize method bound to this config."""
    linear_method = CompressedTensorsLinearMethod(self)
    return linear_method
def get_quant_method(
    self,
    layer: torch.nn.Module,
    prefix: str,
) -> Optional["QuantizeMethodBase"]:
    """Select the quantize method matching the given layer's type.

    Returns a linear method for linear layers, a kv-cache method for
    attention layers, and ``None`` for anything else.
    """
    # Imported lazily to avoid a circular import at module load time.
    from vllm.attention.layer import Attention

    method: Optional["QuantizeMethodBase"] = None
    if isinstance(layer, LinearBase):
        method = CompressedTensorsLinearMethod(self)
    elif isinstance(layer, Attention):
        method = CompressedTensorsKVCacheMethod(self)
    return method
57
65
58
66
@classmethod
@@ -85,7 +93,8 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
85
93
86
94
return cls (target_scheme_map = target_scheme_map ,
87
95
ignore = ignore ,
88
- quant_format = quant_format )
96
+ quant_format = quant_format ,
97
+ kv_cache_scheme = config .get ("kv_cache_scheme" ))
89
98
90
99
@classmethod
91
100
def get_config_filenames (cls ) -> List [str ]:
@@ -309,3 +318,47 @@ def apply(self,
309
318
if scheme is None :
310
319
raise ValueError ("A scheme must be defined for each layer" )
311
320
return scheme .apply_weights (layer , x , bias = bias )
321
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from compressed-tensors
    checkpoints.
    """

    def __init__(self, quant_config: CompressedTensorsConfig):
        # Fail fast on unsupported schemes before wiring up the base
        # kv-cache method.
        self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
        super().__init__(quant_config)

    @staticmethod
    def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
        """
        Validator for the kv cache scheme. Useful for controlling the
        kv cache quantization schemes, that are being supported in vLLM.

        Raises NotImplementedError for any scheme other than symmetric,
        per-tensor, 8-bit float quantization.

        :param kv_cache_scheme: the compressed-tensors kv cache scheme;
            ``None`` (kv cache not quantized) is accepted as-is
        """
        if kv_cache_scheme is None:
            return

        type_ = kv_cache_scheme.get("type")
        num_bits = kv_cache_scheme.get("num_bits")

        # BUGFIX: the original condition used `and`, which only raised when
        # BOTH properties were wrong — e.g. type="float"/num_bits=4 or
        # type="int"/num_bits=8 slipped through validation. The only
        # supported combination is num_bits=8 AND type=float, so reject
        # when either property deviates.
        if type_ != "float" or num_bits != 8:
            raise NotImplementedError(
                "Currently supported kv cache quantization is "
                "num_bits=8, type=float, however "
                f"received num_bits={num_bits}, type={type_}")

        strategy = kv_cache_scheme.get("strategy")
        if strategy != "tensor":
            raise NotImplementedError(
                "Only support per-tensor scaling factor "
                "for compressed-tensors KV cache. "
                f"Expected strategy: tensor, found strategy: {strategy}")

        is_symmetric = kv_cache_scheme.get("symmetric")
        if not is_symmetric:
            raise NotImplementedError(
                "Only support symmetric scaling factor "
                "for compressed-tensors KV cache. "
                f"However found symmetric: {is_symmetric}")
0 commit comments