Commit 3e1c76c

Fix: Respect sparsity_config.ignore in Cutlass Integration (#12517)
This PR addresses a bug in the Cutlass integration where the `sparsity_config.ignore` list was not being respected. When only a subset of modules was configured as Sparse24, the system incorrectly selected Cutlass for non-sparse modules as well. This update ensures the correct scheme is selected for non-sparse modules, fixing this behavior.

---

### Changes

- Updated logic to correctly respect `sparsity_config.ignore`.
- Ensured non-sparse modules use the appropriate scheme instead of defaulting to Cutlass.

---

<details>
<summary>Testing Setup</summary>

The fix has been tested on top of [this diff](#12097).

#### Steps to Test:

```bash
git checkout -b my-test-branch origin/rahul-bitmask-additions  # compressed Cutlass support
git revert --no-edit aa2cd2c  # revert Tyler's commit to turn off Cutlass for W16A16
git cherry-pick ca624cd  # this branch
```

#### Additional Patch Required:

```diff
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index a54177c1c..f916dd0c9 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -9,7 +9,7 @@ from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)
 from pydantic import BaseModel
-
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     should_ignore_layer)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform
-
+logger = init_logger(__name__)
 __all__ = ["CompressedTensorsLinearMethod"]

 SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
```

Apply using:

```bash
git apply logging-patch.patch
```

</details>

---

<details>
<summary>Models Tested</summary>

- `nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-partial-24`
- `nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-full-sparse24`
- `nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-partial-24-entire-fp8-compressed`
- `nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-partial-24-remaining-fp8-compressed`

</details>

---

<details>
<summary>Example Output</summary>

#### Layers 0-5 (Sparse24)

```
Using scheme: CompressedTensors24 for model.layers.0.self_attn.qkv_proj
Using scheme: CompressedTensors24 for model.layers.0.self_attn.o_proj
Using scheme: CompressedTensors24 for model.layers.0.mlp.gate_up_proj
Using scheme: CompressedTensors24 for model.layers.0.mlp.down_proj
...
```

#### Layers 6+ (Non-Sparse, FP8)

```
Using scheme: CompressedTensorsW8A8Fp8 for model.layers.6.self_attn.qkv_proj
Using scheme: CompressedTensorsW8A8Fp8 for model.layers.6.self_attn.o_proj
Using scheme: CompressedTensorsW8A8Fp8 for model.layers.6.mlp.gate_up_proj
Using scheme: CompressedTensorsW8A8Fp8 for model.layers.6.mlp.down_proj
...
```

</details>

**Note:** This fix assumes that all modules in fused layers such as `QKV_proj` and `Gate_up_proj` follow the same quantization/pruning scheme.

---

For related tasks using the Asana app for GitHub, refer to [this link](https://app.asana.com/0/0/1209227810815160).

Signed-off-by: Rahul Tuli <[email protected]>
1 parent cfa134d commit 3e1c76c
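
For context, the `sparsity_config.ignore` list this commit respects comes from the `quantization_config` section of a checkpoint's `config.json`. A minimal sketch of its shape, rendered as a Python dict; every field value below is an illustrative assumption, not metadata from the tested models:

```python
# Hypothetical sparsity_config; all values here are illustrative assumptions.
sparsity_config = {
    "format": "sparse-bitmask",   # compression format (example value)
    "sparsity_structure": "2:4",  # semi-structured 2:4 sparsity
    # Modules pruned to 2:4; candidates for the Cutlass sparse kernels.
    "targets": ["Linear"],
    # Modules left dense. Before this fix, layers matching these entries
    # could still be routed to the Cutlass 2:4 scheme.
    "ignore": ["re:.*layers\\.6.*", "lm_head"],
}
```

Entries of this kind are roughly what steer the non-sparse layers of a partial-24 checkpoint to a dense scheme such as `CompressedTensorsW8A8Fp8` in the example output above.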

File tree

2 files changed: +91 −28 lines changed

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 37 additions & 22 deletions
```diff
@@ -1,4 +1,5 @@
-from typing import Any, Dict, List, Literal, Optional, cast
+from contextlib import suppress
+from typing import Any, Dict, List, Literal, Optional, Tuple, cast

 import torch
 from compressed_tensors.config import (CompressionFormat,
@@ -44,6 +45,7 @@ def __init__(
         ignore: List[str],
         quant_format: str,
         sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
+        sparsity_ignore_list: List[str],
         kv_cache_scheme: Optional[Dict[str, Any]] = None,
         config: Optional[Dict[str, Any]] = None,
     ):
@@ -54,6 +56,7 @@ def __init__(
         self.target_scheme_map = target_scheme_map
         self.kv_cache_scheme = kv_cache_scheme
         self.sparsity_scheme_map = sparsity_scheme_map
+        self.sparsity_ignore_list = sparsity_ignore_list
         self.config = config

     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
@@ -98,36 +101,40 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
         quant_format = cast(str, config.get("format"))
         target_scheme_map = cls._quantization_scheme_map_from_config(
             config=config)
-        sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
+        sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
             config=config)

         return cls(
             target_scheme_map=target_scheme_map,
             ignore=ignore,
             quant_format=quant_format,
             sparsity_scheme_map=sparsity_scheme_map,
+            sparsity_ignore_list=sparsity_ignore_list,
             config=config,
         )

     @classmethod
-    def _sparsity_scheme_map_from_config(
-            cls, config: Dict[str,
-                              Any]) -> Dict[str, SparsityCompressionConfig]:
+    def _parse_sparsity_config(
+        cls, config: Dict[str, Any]
+    ) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]:
         """
         :param config: The `quantization_config` dictionary from config.json
-        :return: A dictionary mapping target layer names to their corresponding
-            sparsity compression configurations
+        :return: A tuple with two elements
+            1. A dictionary mapping target layer names to their corresponding
+                sparsity_config
+            2. A list of layer names to ignore for sparsity
         """
         if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)):
-            return dict()
+            return dict(), []

         sparsity_config = SparsityCompressionConfig.model_validate(
             sparsity_config)
         sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
             target: sparsity_config
             for target in sparsity_config.targets or list()
         }
-        return sparse_scheme_map
+        sparsity_ignore_list = sparsity_config.ignore or list()
+        return sparse_scheme_map, sparsity_ignore_list

     @classmethod
     def _quantization_scheme_map_from_config(
@@ -352,7 +359,6 @@ def get_scheme(self,
         """
         compressed-tensors supports non uniform in the following way:

-        ignore: List of layer_names or nn.Module names to be ignored.
         targets of config_groups: There can be N config_groups which each
         have a quantization scheme. Each config_group has a list of targets
         which can be a full layer_name, a regex for a layer_name, or
@@ -370,6 +376,8 @@ def get_scheme(self,
         # need to make accelerate optional in ct to do this

         # Will be empty for models with only sparsity
+        weight_quant = input_quant = None
+        sparsity_scheme: Optional[SparsityCompressionConfig] = None
         if self.target_scheme_map:
             matched_target = find_matched_target(
                 layer_name=layer_name,
@@ -379,19 +387,24 @@ def get_scheme(self,
             scheme_dict = self.target_scheme_map[matched_target]
             weight_quant = scheme_dict.get("weights")
             input_quant = scheme_dict.get("input_activations")
-        elif self.sparsity_scheme_map:
-            matched_target = find_matched_target(
-                layer_name=layer_name,
-                module=layer,
-                targets=self.sparsity_scheme_map.keys())
-            weight_quant = None
-            input_quant = None

-        # For models with sparsity, assumes that the sparse layers are also
-        # quantized for cutlass 2:4 support
-        sparsity_scheme: Optional[
-            SparsityCompressionConfig] = self.sparsity_scheme_map.get(
-                matched_target)
+        if self.sparsity_scheme_map:
+            is_ignored = False
+            with suppress(ValueError):
+                is_ignored = find_matched_target(
+                    layer_name=layer_name,
+                    module=layer,
+                    targets=self.sparsity_ignore_list)
+
+            # if the layer is in the sparsity ignore list,
+            # we should not apply any sparsity scheme
+
+            if not is_ignored:
+                matched_target = find_matched_target(
+                    layer_name=layer_name,
+                    module=layer,
+                    targets=self.sparsity_scheme_map.keys())
+                sparsity_scheme = self.sparsity_scheme_map.get(matched_target)

         if self.supports_cutlass_24(weight_quant=weight_quant,
                                     input_quant=input_quant,
@@ -419,6 +432,8 @@ def get_scheme(self,
         # Raise error if device does not support the scheme
         # (e.g. fp8 needs ada lovelace)
         self._check_scheme_supported(scheme.get_min_capability())
+        logger.debug("Using scheme: %s for %s", scheme.__class__.__name__,
+                     layer_name)
         return scheme

     def get_cache_scale(self, name: str) -> Optional[str]:
```
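
Distilled from the `get_scheme` hunks above, here is a self-contained sketch of the new ignore-gated selection flow. The `find_matched_target` below is a simplified stand-in for vLLM's helper (which likewise raises `ValueError` when nothing matches); the layer and target names are illustrative:

```python
from contextlib import suppress
from typing import Iterable, Optional


def find_matched_target(layer_name: str, targets: Iterable[str]) -> str:
    """Simplified stand-in: exact or suffix match, raising like the real helper."""
    for target in targets:
        if layer_name == target or layer_name.endswith(f".{target}"):
            return target
    raise ValueError(f"Unable to find matching target for {layer_name}")


def select_sparsity_scheme(layer_name: str,
                           sparsity_targets: Iterable[str],
                           sparsity_ignore_list: Iterable[str]) -> Optional[str]:
    # Layers on the ignore list must not receive a sparsity scheme,
    # even if they also match a sparsity target.
    is_ignored = False
    with suppress(ValueError):  # "no match" simply means "not ignored"
        is_ignored = bool(find_matched_target(layer_name, sparsity_ignore_list))
    if is_ignored:
        return None
    with suppress(ValueError):
        return find_matched_target(layer_name, sparsity_targets)
    return None


# Layer 0 matches a sparsity target; layer 6 is explicitly ignored.
print(select_sparsity_scheme("model.layers.0.mlp.down_proj",
                             ["down_proj"],
                             ["model.layers.6.mlp.down_proj"]))  # -> down_proj
print(select_sparsity_scheme("model.layers.6.mlp.down_proj",
                             ["down_proj"],
                             ["model.layers.6.mlp.down_proj"]))  # -> None
```

The `with suppress(ValueError)` around the ignore lookup mirrors the diff: failing to match the ignore list is the normal case for sparse layers, so the exception is swallowed rather than treated as an error.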

vllm/model_executor/layers/quantization/compressed_tensors/utils.py

Lines changed: 54 additions & 6 deletions
```diff
@@ -12,7 +12,7 @@ def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
         CompressionFormat.naive_quantized.value,
         CompressionFormat.int_quantized.value,
-        CompressionFormat.float_quantized.value
+        CompressionFormat.float_quantized.value,
     ]
     return format in _ACTIVATION_QUANTIZATION_FORMATS

@@ -68,7 +68,7 @@ def should_ignore_layer(layer_name: Optional[str],
 def check_equal_or_regex_match(layer_name: str,
                                targets: Iterable[str]) -> bool:
     """
-    Checks whether a layer_name is exactly equal or a regex match for 
+    Checks whether a layer_name is exactly equal or a regex match for
     if target starts with 're:' to any target in list.
     """
     for target in targets:
@@ -77,17 +77,64 @@ def check_equal_or_regex_match(layer_name: str,
         return False


+def _handle_fused_layers(func):
+    """
+    Decorator to handle fused layers by mapping vllm fused layer names
+    to their corresponding unfused layer names for quantization/pruning schemes.
+    """
+    # fused_layer_name -> unfused_layer_name
+    fused_layer_map = {
+        "qkv_proj": "q_proj",
+        "gate_up_proj": "up_proj",
+    }
+
+    def fused_layer_handler(layer_name: Optional[str], module: Module,
+                            targets: Iterable[str]) -> Optional[str]:
+        """
+        Wrapper function specifically designed to support the
+        find_matched_target function.
+
+        It handles cases where the provided layer name corresponds to a
+        fused layer in vllm, mapping it to its equivalent unfused layer name
+        based on the predefined fused_layer_map. If the original layer name
+        raises a ValueError in the wrapped function, this handler
+        will attempt to resolve the issue by substituting with unfused
+        layer name.
+
+        :param layer_name: Name of the layer, which may be fused.
+        :param module: An instance of torch.nn.Module.
+        :param targets: A list of target names or patterns to match.
+        :return: The result of the wrapped find_matched_target function with
+            the resolved layer name.
+        :raises ValueError: If the layer name cannot be resolved to a
+            valid target.
+        """
+        try:
+            return func(layer_name, module, targets)
+        except ValueError:
+            if layer_name is None:
+                layer_name = ""
+            parent_name, fused_proj_name = layer_name.rsplit(".", 1)
+            unfused_proj_name = fused_layer_map.get(fused_proj_name,
+                                                    fused_proj_name)
+            new_layer_name = f"{parent_name}.{unfused_proj_name}"
+            return func(new_layer_name, module, targets)
+
+    return fused_layer_handler
+
+
+@_handle_fused_layers
 def find_matched_target(layer_name: Optional[str], module: Module,
                         targets: Iterable[str]) -> str:
     """
     Helper function to look up which "target" in the compressed-tensors
     config that a layer corresponds to.

-    Recall that a compressed-tensors configs has a concept of 
+    Recall that a compressed-tensors configs has a concept of
     config_groups, where each layer can be quantized with with a different
     scheme.

-    targets in each config_group will be a list of either layer names 
+    targets in each config_group will be a list of either layer names
     (or regexes corresponding to layer names) or names of torch Modules.

     First, we try to match the layer_name with a target
@@ -107,8 +154,9 @@ def find_matched_target(layer_name: Optional[str], module: Module,
                       or _match_fused_layer(layer_name, targets))

     if matched_target is None:
-        raise ValueError(f"Unable to find matching target for {module} in the "
-                         "compressed-tensors config.")
+        raise ValueError(
+            f"Unable to find matching target for {layer_name} in the "
+            "compressed-tensors config.")

     return matched_target
```
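
The `_handle_fused_layers` decorator above retries a failed lookup under an unfused projection name, so fused vLLM modules like `qkv_proj` can match checkpoint targets written against `q_proj`. A toy sketch of the same retry pattern; the matcher here is a simplified stand-in (it drops the `module` argument and uses suffix matching), not vLLM's real helper:

```python
from typing import Iterable, Optional

# fused vLLM layer name -> representative unfused checkpoint name,
# mirroring the fused_layer_map in the diff above
FUSED_LAYER_MAP = {"qkv_proj": "q_proj", "gate_up_proj": "up_proj"}


def handle_fused_layers(func):
    """Retry a failed target lookup under the unfused projection name."""
    def wrapper(layer_name: Optional[str], targets: Iterable[str]) -> str:
        try:
            return func(layer_name, targets)
        except ValueError:
            parent, proj = (layer_name or "").rsplit(".", 1)
            unfused = FUSED_LAYER_MAP.get(proj, proj)
            return func(f"{parent}.{unfused}", targets)
    return wrapper


@handle_fused_layers
def match_target(layer_name: Optional[str], targets: Iterable[str]) -> str:
    """Toy matcher: succeeds only on an exact suffix match."""
    for target in targets:
        if layer_name and layer_name.endswith(target):
            return target
    raise ValueError(f"no target for {layer_name}")


# "qkv_proj" is not in the targets, but its unfused counterpart
# "q_proj" is, so the retried lookup still succeeds.
print(match_target("model.layers.0.self_attn.qkv_proj", ["q_proj"]))  # -> q_proj
```

This fallback leans on the commit's stated assumption that all sub-projections of a fused layer share one quantization/pruning scheme; only then is it safe to let `q_proj` stand in for the whole `qkv_proj` module.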
