
Commit 9e0b558

[Misc] Support FP8 kv cache scales from compressed-tensors (#6528)
1 parent e519ae0 commit 9e0b558

7 files changed, +186 -75 lines changed


tests/quantization/test_compressed_tensors.py

Lines changed: 7 additions & 0 deletions

@@ -150,3 +150,10 @@ def test_compressed_tensors_fp8(vllm_runner):
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
+
+
+def test_compressed_tensors_kv_cache(vllm_runner):
+    model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
+    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
+        output = llm.generate_greedy("Hello world!", max_tokens=20)
+        assert output
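
Outside the test fixture, the same path can be exercised through vLLM's offline entrypoint. A minimal sketch, assuming the LLM constructor forwards kv_cache_dtype to the engine arguments (this snippet is not part of the commit):

from vllm import LLM, SamplingParams

# Illustrative only: load the compressed-tensors checkpoint used by the test
# and request an fp8 KV cache so the checkpoint's k/v scales are applied.
llm = LLM(model="nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme",
          kv_cache_dtype="fp8")
params = SamplingParams(temperature=0.0, max_tokens=20)  # greedy decoding
outputs = llm.generate(["Hello world!"], params)
print(outputs[0].outputs[0].text)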

vllm/attention/layer.py

Lines changed: 11 additions & 12 deletions

@@ -9,7 +9,7 @@
 from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 
 
 class Attention(nn.Module):
@@ -59,19 +59,18 @@ def __init__(
         quant_method = quant_config.get_quant_method(
             self, prefix=prefix) if quant_config else None
         if quant_method is not None:
-            assert isinstance(quant_method, Fp8KVCacheMethod)
+            assert isinstance(quant_method, BaseKVCacheMethod)
             # TODO (mgoin): kv cache dtype should be specified in the FP8
             # checkpoint config and become the "auto" behavior
-            if "fp8" in self.kv_cache_dtype:
-                if self.kv_cache_dtype == "fp8_e5m2":
-                    raise ValueError("fp8_e5m2 kv-cache is not supported with "
-                                     "fp8 checkpoints.")
-                # When FP8 quantization is enabled, we make a parameter
-                # "kv_scale" so that it can be loaded from FP8 checkpoint.
-                # The k/v_scale will then be converted back to
-                # self._kv_scale in a native float32 value after weight loading
-                self.quant_method = quant_method
-                self.quant_method.create_weights(self)
+            if self.kv_cache_dtype == "fp8_e5m2":
+                raise ValueError("fp8_e5m2 kv-cache is not supported with "
+                                 "fp8 checkpoints.")
+            # If quantization is enabled, we make "k_scale" and "v_scale"
+            # parameters so that it can be loaded from the model checkpoint.
+            # The k/v_scale will then be converted back to native float32
+            # values after weight loading.
+            self.quant_method = quant_method
+            self.quant_method.create_weights(self)
 
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 58 additions & 5 deletions

@@ -5,7 +5,7 @@
 
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
     CompressedTensorsScheme, CompressedTensorsUnquantized,
@@ -15,18 +15,23 @@
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     QuantizationType, find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform
 
 
 class CompressedTensorsConfig(QuantizationConfig):
 
-    def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str],
-                 quant_format: str):
+    def __init__(self,
+                 target_scheme_map: Dict[str, Any],
+                 ignore: List[str],
+                 quant_format: str,
+                 kv_cache_scheme: Optional[Dict[str, Any]] = None):
 
         self.ignore = ignore
         self.quant_format = quant_format
         # Map from [target -> scheme]
         self.target_scheme_map = target_scheme_map
+        self.kv_cache_scheme = kv_cache_scheme
 
     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -50,9 +55,12 @@ def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
-    ) -> Optional["CompressedTensorsLinearMethod"]:
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
             return CompressedTensorsLinearMethod(self)
+        if isinstance(layer, Attention):
+            return CompressedTensorsKVCacheMethod(self)
         return None
 
     @classmethod
@@ -85,7 +93,8 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
 
         return cls(target_scheme_map=target_scheme_map,
                    ignore=ignore,
-                   quant_format=quant_format)
+                   quant_format=quant_format,
+                   kv_cache_scheme=config.get("kv_cache_scheme"))
 
     @classmethod
     def get_config_filenames(cls) -> List[str]:
@@ -309,3 +318,47 @@ def apply(self,
         if scheme is None:
             raise ValueError("A scheme must be defined for each layer")
         return scheme.apply_weights(layer, x, bias=bias)
+
+
+class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from compressed-tensors
+    checkpoints.
+    """
+
+    def __init__(self, quant_config: CompressedTensorsConfig):
+        self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
+        super().__init__(quant_config)
+
+    @staticmethod
+    def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
+        """
+        Validator for the kv cache scheme. Useful for controlling the
+        kv cache quantization schemes, that are being supported in vLLM
+        :param kv_cache_scheme: the compressed-tensors kv cache scheme
+        """
+        if kv_cache_scheme is None:
+            return
+
+        type_ = kv_cache_scheme.get("type")
+        num_bits = kv_cache_scheme.get("num_bits")
+
+        if type_ != "float" and num_bits != 8:
+            raise NotImplementedError(
+                "Currently supported kv cache quantization is "
+                "num_bits=8, type=float, however "
+                f"received num_bits={num_bits}, type={type_}")
+
+        strategy = kv_cache_scheme.get("strategy")
+        if strategy != "tensor":
+            raise NotImplementedError(
+                "Only support per-tensor scaling factor "
+                "for compressed-tensors KV cache. "
+                f"Expected strategy: tensor, found strategy: {strategy}")
+
+        is_symmetric = kv_cache_scheme.get("symmetric")
+        if not is_symmetric:
+            raise NotImplementedError(
+                "Only support symmetric scaling factor "
+                "for compressed-tensors KV cache. "
+                f"However found symmetric: {is_symmetric}")

vllm/model_executor/layers/quantization/compressed_tensors/utils.py

Lines changed: 17 additions & 0 deletions

@@ -209,6 +209,23 @@ def _find_first_match(value: str,
     return None
 
 
+def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
+    """
+    Check whether the param name matches the format for k/v cache scales
+    in compressed-tensors. If this is the case, return its equivalent
+    param name expected by vLLM
+
+    :param name: param name
+    :return: matching param name for KV cache scale in vLLM
+    """
+    if name.endswith(".output_scale") and ".k_proj" in name:
+        return name.replace(".k_proj.output_scale", ".attn.k_scale")
+    if name.endswith(".output_scale") and ".v_proj" in name:
+        return name.replace(".v_proj.output_scale", ".attn.v_scale")
+    # If no matches, return None
+    return None
+
+
 def _is_equal_or_regex_match(value: str,
                              target: str,
                              check_contains: bool = False) -> bool:
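
A quick illustration of the remapping this helper performs; the parameter names below are hypothetical Llama-style names chosen only to show the pattern:

from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)

# k/v projection output scales are renamed to the Attention module's
# k_scale/v_scale parameters; anything else falls through to None.
print(get_compressed_tensors_cache_scale(
    "model.layers.0.self_attn.k_proj.output_scale"))
# -> model.layers.0.self_attn.attn.k_scale
print(get_compressed_tensors_cache_scale(
    "model.layers.0.self_attn.v_proj.output_scale"))
# -> model.layers.0.self_attn.attn.v_scale
print(get_compressed_tensors_cache_scale(
    "model.layers.0.self_attn.q_proj.weight"))
# -> None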

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 5 additions & 58 deletions

@@ -11,6 +11,7 @@
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -400,64 +401,10 @@ def apply(self,
                          topk_group=topk_group)
 
 
-class Fp8KVCacheMethod(QuantizeMethodBase):
-    """Supports loading kv-cache scaling factors from FP8 checkpoints.
+class Fp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from FP8 checkpoints.
     """
 
     def __init__(self, quant_config: Fp8Config):
-        self.quant_config = quant_config
-
-    def create_weights(self, layer: torch.nn.Module):
-        """Create "weight" (aka k_scale and v_scale) for an attention layer.
-
-        Args:
-            layer: The layer that is using the QuantizeMethodBase factory.
-        """
-        # Initialize the KV cache scales to -1.0, which is an invalid value.
-        # If the k/v_scale appears in the checkpoint, it will be
-        # overwritten when loading weights.
-        layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
-        layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
-
-    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
-        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
-
-    def process_weights_after_loading(self, layer: Module) -> None:
-        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
-        # regardless whether the kv-scale is available in the checkpoint.
-        if layer.kv_cache_dtype != "auto":
-            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
-                # We prefer to use separate k_scale and v_scale if present
-                k_scale = layer.k_scale.to("cpu").tolist()
-                v_scale = layer.v_scale.to("cpu").tolist()
-            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
-                # If no scales were loaded (both scales are invalid negative
-                # values), use the default value of 1.0
-                k_scale = Parameter(torch.tensor(1.0), requires_grad=False)
-                v_scale = Parameter(torch.tensor(1.0), requires_grad=False)
-            else:
-                # If we find a single kv_scale in the checkpoint, we remap
-                # kv_scale to k_scale during weight loading, and duplicate
-                # k_scale to v_scale here
-                assert layer.k_scale > 0.0
-                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
-                k_scale = scale_to_duplicate.to("cpu").tolist()
-                v_scale = scale_to_duplicate.to("cpu").tolist()
-
-            if not isinstance(k_scale, float) or not isinstance(
-                    v_scale, float):
-                raise ValueError("Only support per-tensor scaling factor "
-                                 "for fp8 KV cache")
-
-            # These are used in the final Attention.forward()
-            layer._k_scale = k_scale
-            layer._v_scale = v_scale
-            if (layer._k_scale == 1.0 and layer._v_scale == 1.0
-                    and "e5m2" not in layer.kv_cache_dtype):
-                print_warning_once(
-                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                    "may cause accuracy issues. Please make sure k/v_scale "
-                    "scaling factors are available in the fp8 checkpoint.")
-
-        del layer.k_scale
-        del layer.v_scale
+        super().__init__(quant_config)
vllm/model_executor/layers/quantization/kv_cache.py (new file)

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+import torch
+
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.utils import print_warning_once
+
+
+class BaseKVCacheMethod(QuantizeMethodBase):
+    """
+    Quant method that adds `_k_scale` and `_v_scale` attributes to the
+    Attention layer to support loading those scaling factors from checkpoints.
+    The k/v_scale will be used to:
+        - quantize k/v_cache entries before saving them to the cache
+        - dequantize k/v_cache entries before fetching them from the cache
+
+    :param quant_config: the appropriate QuantizationConfig
+    """
+
+    def __init__(self, quant_config: QuantizationConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module):
+        """
+        Create "weight" (aka k_scale and v_scale) for an attention layer.
+        """
+        # Initialize the KV cache scales to -1.0, which is an invalid value.
+        # If the k/v_scale appears in the checkpoint, it will be
+        # overwritten when loading weights.
+        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0),
+                                           requires_grad=False)
+        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
+                                           requires_grad=False)
+
+    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
+        raise RuntimeError(
+            f"{self.__class__.__name__}.apply should not be called.")
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
+        # regardless whether the kv-scale is available in the checkpoint.
+        if layer.kv_cache_dtype != "auto":
+            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+                # We prefer to use separate k_scale and v_scale if present
+                k_scale = layer.k_scale.to("cpu").tolist()
+                v_scale = layer.v_scale.to("cpu").tolist()
+            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+                # If no scales were loaded (both scales are invalid negative
+                # values), use the default value of 1.0
+                k_scale = torch.nn.Parameter(torch.tensor(1.0),
+                                             requires_grad=False)
+                v_scale = torch.nn.Parameter(torch.tensor(1.0),
+                                             requires_grad=False)
+            else:
+                # If we find a single kv_scale in the checkpoint, we remap
+                # kv_scale to k_scale during weight loading, and duplicate
+                # k_scale to v_scale here
+                assert layer.k_scale > 0.0
+                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+                k_scale = scale_to_duplicate.to("cpu").tolist()
+                v_scale = scale_to_duplicate.to("cpu").tolist()
+
+            if not isinstance(k_scale, float) or not isinstance(
+                    v_scale, float):
+                raise ValueError("Only support per-tensor scaling factor "
+                                 "for fp8 KV cache")
+
+            # These are used in the final Attention.forward()
+            layer._k_scale = k_scale
+            layer._v_scale = v_scale
+            if (layer._k_scale == 1.0 and layer._v_scale == 1.0
+                    and "e5m2" not in layer.kv_cache_dtype):
+                print_warning_once(
+                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
+                    "may cause accuracy issues. Please make sure k/v_scale "
+                    "scaling factors are available in the fp8 checkpoint.")
+
+        del layer.k_scale
+        del layer.v_scale
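
To see the new base class end to end, here is a small sketch that drives it against a bare torch.nn.Module standing in for vLLM's Attention layer; the stand-in, the None quant_config, and the example scale values are assumptions made only for illustration:

import torch

from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod

# Bare Module used as a stand-in for vllm.attention.layer.Attention.
layer = torch.nn.Module()
layer.kv_cache_dtype = "fp8"

method = BaseKVCacheMethod(quant_config=None)  # quant_config is only stored
method.create_weights(layer)  # k_scale and v_scale start at -1.0 (unset)

# Pretend the checkpoint loader wrote per-tensor scales into the parameters.
layer.k_scale.data.fill_(0.023)
layer.v_scale.data.fill_(0.031)

method.process_weights_after_loading(layer)
# The k_scale/v_scale parameters are gone; _k_scale/_v_scale are plain
# Python floats used by Attention.forward().
print(layer._k_scale, layer._v_scale)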

vllm/model_executor/models/llama.py

Lines changed: 10 additions & 0 deletions

@@ -39,6 +39,8 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    get_compressed_tensors_cache_scale)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -467,6 +469,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if scale_name := get_compressed_tensors_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = loaded_weight[0]
+                weight_loader(param, loaded_weight)
+                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
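
The effect of this hook can be sketched in isolation with a mocked params_dict; every name below is hypothetical, and the copy_ call stands in for default_weight_loader:

import torch

from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)

# Hypothetical checkpoint entry; the loaded_weight[0] indexing in the hunk
# above suggests the scale arrives as a 1-element tensor.
name = "model.layers.0.self_attn.k_proj.output_scale"
loaded_weight = torch.tensor([0.023])

# Stand-in for the model's params_dict entry created by create_weights().
params_dict = {
    "model.layers.0.self_attn.attn.k_scale":
    torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False),
}

if scale_name := get_compressed_tensors_cache_scale(name):
    param = params_dict[scale_name]
    param.data.copy_(loaded_weight[0])  # default_weight_loader is essentially a copy
print(params_dict["model.layers.0.self_attn.attn.k_scale"])  # now holds 0.023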
