Commit c2637a6

[Kernel] w4a16 support for compressed-tensors (#5385)
Co-authored-by: Robert Shaw <[email protected]>
1 parent: 8840753 · commit: c2637a6
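This commit wires 4-bit-weight, 16-bit-activation (w4a16) compressed-tensors checkpoints into vLLM by reusing the GPTQ Marlin kernels. As a rough usage sketch, not part of the diff: one of the test checkpoints added below should load through the standard vLLM entry point, with the checkpoint's embedded quantization config routing its linear layers to the new CompressedTensorsW4A16 scheme. Everything except the model name (taken from the new test) is ordinary vLLM API usage.

# Usage sketch (assumes a GPU supported by the Marlin kernels); the model
# name comes from the tests added in this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/tinyllama-oneshot-w4a16-group128-v2")
params = SamplingParams(temperature=0.0, max_tokens=32)
out = llm.generate(["Compressed-tensors w4a16 packs"], params)
print(out[0].outputs[0].text)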

File tree: 4 files changed, +230 / -10 lines changed


tests/quantization/test_compressed_tensors.py

Lines changed: 25 additions & 2 deletions
@@ -3,12 +3,13 @@
 Run `pytest tests/quantization/test_compressed_tensors.py`.
 """

+import pytest
 import torch

 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
-    CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsLinearMethod, CompressedTensorsW4A16,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)


 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@@ -60,3 +61,25 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
         assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
         assert qkv_proj.weight.dtype is torch.int8
+
+
+@pytest.mark.parametrize("w4a16_args", [
+    ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None),
+    ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128),
+])
+def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
+    model, strategy, group = w4a16_args
+    with vllm_runner(model) as llm:
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        layer = model.model.layers[0]
+
+        qkv_proj = layer.self_attn.qkv_proj
+        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16)
+
+        assert qkv_proj.scheme.strategy == strategy
+        assert qkv_proj.scheme.group_size == group
+
+        assert qkv_proj.weight_packed.dtype is torch.int32
+        assert qkv_proj.weight_scale.dtype is torch.float16
+        assert qkv_proj.weight_packed.pack_factor == 8
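The final assertions pin down the packed storage layout: 4-bit weights are stored in int32 words, so each word carries 32 // 4 = 8 quantized values, while the scales keep the checkpoint's fp16 dtype. A minimal arithmetic sketch of that pack factor (plain Python, no vLLM needed; the 4096-wide layer is hypothetical):

# Pack-factor arithmetic behind the asserts above.
num_bits = 4
pack_factor = 32 // num_bits                 # int4 values per int32 word
assert pack_factor == 8

out_features, in_features = 4096, 4096       # hypothetical layer size
packed_columns = in_features // pack_factor  # 512 int32 columns per row
assert packed_columns == 512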

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 36 additions & 8 deletions
@@ -7,8 +7,8 @@
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsScheme, CompressedTensorsW4A16,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)

@@ -47,16 +47,27 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
         layer_quant_details: Dict[str, Any] = dict()
         ignore: List[str] = config.get("ignore", None)

+        # The quant_config has multiple config_groups, each containing
+        # an input_activations key with details about how the activations are
+        # quantized, a weights key indicating how the weights are quantized,
+        # and a list of targets under the `targets` key, dictating which
+        # layers are impacted by the quantization details. The quantization
+        # details follow the structure defined by the QuantizationArgs
+        # pydantic model, which is used to verify the structure of the
+        # quant_config and also store the details for later use.
         for key, quant_config in config["config_groups"].items():
             targets = quant_config.get("targets")
             for target in targets:
                 layer_quant_details[target] = {}
                 layer_quant_details[target][
-                    "weight"] = QuantizationArgs.parse_obj(
+                    "weights"] = QuantizationArgs.parse_obj(
                         quant_config.get("weights"))
-                layer_quant_details[target][
-                    "input"] = QuantizationArgs.parse_obj(
-                        quant_config.get("input_activations"))
+                try:
+                    layer_quant_details[target][
+                        "input_activations"] = QuantizationArgs.parse_obj(
+                            quant_config.get("input_activations"))
+                except Exception:
+                    layer_quant_details[target]["input_activations"] = None

         return cls(layer_quant_details=layer_quant_details, ignore=ignore)

@@ -86,8 +97,23 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,

         return is_8_bits and is_token_tensor and is_symmetric and is_dynamic

+    def _is_w4a16(self, weight_quant: BaseModel,
+                  input_quant: BaseModel) -> bool:
+        input_quant_none = input_quant is None
+        is_4_bits = weight_quant.num_bits == 4
+        is_symmetric = weight_quant.symmetric
+        is_static = not weight_quant.dynamic
+
+        return is_4_bits and input_quant_none and is_symmetric and is_static
+
     def _get_schema(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> "CompressedTensorsScheme":
+
+        if self._is_w4a16(weight_quant, input_quant):
+            return CompressedTensorsW4A16(num_bits=weight_quant.num_bits,
+                                          strategy=weight_quant.strategy,
+                                          group_size=weight_quant.group_size)
+
         if self._is_static_tensor_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8StaticTensor()

@@ -113,8 +139,9 @@ def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
             raise ValueError(
                 f"Could not find quantization details for {layer}.")

-        return self._get_schema(weight_quant=layer_quant_details["weight"],
-                                input_quant=layer_quant_details["input"])
+        return self._get_schema(
+            weight_quant=layer_quant_details["weights"],
+            input_quant=layer_quant_details["input_activations"])


 class CompressedTensorsLinearMethod(LinearMethodBase):
@@ -140,6 +167,7 @@ def create_weights(self, layer: torch.nn.Module,
             layer=layer,
             input_size_per_partition=input_size_per_partition,
             output_partition_sizes=output_partition_sizes,
+            input_size=input_size,
             output_size=output_size,
             params_dtype=params_dtype,
             weight_loader=weight_loader)
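For orientation, `_is_w4a16` matches a config group whose weights are 4-bit, symmetric, and static, and whose `input_activations` entry is absent (parsed to None above). Below is a hypothetical `config_groups` fragment of that shape; the field names mirror the QuantizationArgs attributes used in this diff, while the target name, group size, and ignore list are illustrative rather than taken from a real checkpoint:

# Hypothetical quantization config that from_config() would accept and that
# _is_w4a16() would route to CompressedTensorsW4A16.
example_quant_config = {
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],     # illustrative target module name
            "weights": {
                "num_bits": 4,         # 4-bit weights
                "symmetric": True,     # symmetric quantization
                "dynamic": False,      # static (not dynamic) scales
                "strategy": "group",
                "group_size": 128,
            },
            # no "input_activations" key: activations stay fp16 (w4a16)
        },
    },
    "ignore": ["lm_head"],             # illustrative ignore list
}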

vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
 from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
+from .compressed_tensors_w4a16 import CompressedTensorsW4A16  # noqa: F401
 from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
     CompressedTensorsW8A8DynamicToken)
 from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
+from typing import Callable, List, Optional
+
+import torch
+from torch.nn import Parameter
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
+    marlin_permute_scales)
+from vllm.model_executor.utils import set_weight_attrs
+
+__all__ = ["CompressedTensorsW4A16"]
+
+
+class CompressedTensorsW4A16(CompressedTensorsScheme):
+
+    def __init__(self,
+                 strategy: str,
+                 num_bits: int,
+                 group_size: Optional[int] = None):
+        self.num_bits = num_bits
+        self.strategy = strategy
+        self.group_size = group_size
+
+        if self.strategy == "group" and self.group_size is None:
+            raise ValueError(
+                "group_size must be given when using strategy group")
+
+    def create_weights(self, layer: torch.nn.Module, input_size: int,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        pack_factor = 32 // self.num_bits
+        output_size_per_partition = sum(output_partition_sizes)
+
+        if self.group_size is not None:
+            group_size = self.group_size
+        else:
+            group_size = input_size
+
+        weight_scale_dim = None
+        scales_and_zp_size = input_size // group_size
+
+        if (input_size != input_size_per_partition
+                and self.group_size is not None):
+            weight_scale_dim = 1
+            scales_and_zp_size = input_size_per_partition // group_size
+
+        weight = Parameter(
+            torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+
+        set_weight_attrs(
+            weight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "packed_dim": 1,
+                "pack_factor": pack_factor
+            })
+        set_weight_attrs(weight, {"weight_loader": weight_loader})
+
+        layer.register_parameter("weight_packed", weight)
+
+        weight_scale = Parameter(
+            torch.empty(
+                output_size_per_partition,
+                scales_and_zp_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+        set_weight_attrs(weight_scale, {
+            "input_dim": weight_scale_dim,
+            "output_dim": 0
+        })
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # A 2D array defining the original shape of the weights
+        # before packing
+        weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
+                                 requires_grad=False)
+
+        layer.register_parameter("weight_shape", weight_shape)
+        set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+
+        layer.input_size = input_size
+        layer.marlin_state = GPTQMarlinState.REPACK
+        layer.is_k_full = True
+        layer.group_size = group_size
+
+        max_workspace_size = (
+            output_size_per_partition //
+            GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
+
+        workspace = torch.zeros(max_workspace_size,
+                                dtype=torch.int,
+                                requires_grad=False)
+        layer.workspace = workspace
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        size_m = reshaped_x.shape[0]
+        part_size_n = layer.output_size_per_partition
+        part_size_k = layer.input_size_per_partition
+
+        out_shape = x.shape[:-1] + (part_size_n, )
+
+        if layer.marlin_state == GPTQMarlinState.REPACK:
+            layer.marlin_state = GPTQMarlinState.READY
+
+            # Newly generated tensors need to replace existing tensors that are
+            # already registered as parameters by vLLM (and won't be freed)
+            def replace_tensor(name, new_t):
+                # It is important to use resize_() here since it ensures
+                # the same buffer is reused
+                getattr(layer, name).resize_(new_t.shape)
+                getattr(layer, name).copy_(new_t)
+                del new_t
+
+            cur_device = layer.weight_packed.device
+
+            # Reset g_idx related tensors
+            layer.g_idx = Parameter(torch.empty(0,
+                                                dtype=torch.int,
+                                                device=cur_device),
+                                    requires_grad=False)
+            layer.g_idx_sort_indices = Parameter(torch.empty(
+                0, dtype=torch.int, device=cur_device),
+                                                 requires_grad=False)
+
+            # Repack weights
+            marlin_qweight = ops.gptq_marlin_repack(
+                layer.weight_packed.t().contiguous(), layer.g_idx_sort_indices,
+                part_size_k, part_size_n, self.num_bits)
+
+            replace_tensor("weight_packed", marlin_qweight)
+
+            # Permute scales
+            scales_size_k = part_size_k
+            scales_size_n = part_size_n
+
+            marlin_scales = marlin_permute_scales(
+                layer.weight_scale.squeeze().t().contiguous(), scales_size_k,
+                scales_size_n, layer.group_size, self.num_bits)
+            replace_tensor("weight_scale", marlin_scales)
+
+        output = ops.gptq_marlin_gemm(reshaped_x, layer.weight_packed,
+                                      layer.weight_scale, layer.g_idx,
+                                      layer.g_idx_sort_indices,
+                                      layer.workspace, self.num_bits, size_m,
+                                      part_size_n, part_size_k,
+                                      layer.is_k_full)
+        return output.reshape(out_shape)
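In summary, create_weights allocates the packed weight as (output_size_per_partition, input_size_per_partition // pack_factor) in int32 plus one scale per output channel per input group, and apply_weights performs a one-time repack into the Marlin layout (REPACK to READY) before calling gptq_marlin_gemm on each forward pass. A standalone shape sketch under hypothetical per-partition sizes, using plain torch rather than the vLLM helpers:

import torch

# Hypothetical per-partition sizes for one linear layer.
output_size_per_partition = 4096
input_size_per_partition = 4096
num_bits, group_size = 4, 128
pack_factor = 32 // num_bits   # mirrors create_weights above

# Packed weight: one int32 column per pack_factor input features.
weight_packed = torch.empty(output_size_per_partition,
                            input_size_per_partition // pack_factor,
                            dtype=torch.int32)

# One scale per output channel per group of input features.
weight_scale = torch.empty(output_size_per_partition,
                           input_size_per_partition // group_size,
                           dtype=torch.float16)

print(weight_packed.shape)  # torch.Size([4096, 512])
print(weight_scale.shape)   # torch.Size([4096, 32])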
