
Commit 62963d1

[ Misc ] Clean Up CompressedTensorsW8A8 (#6113)
Parent commit: d9e98f4

6 files changed, +44 -95 lines


tests/quantization/test_compressed_tensors.py

Lines changed: 5 additions & 4 deletions
@@ -9,8 +9,7 @@
 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
-    CompressedTensorsWNA16)
+    CompressedTensorsW8A8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationType)

@@ -38,9 +37,10 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
                           CompressedTensorsLinearMethod)
         assert isinstance(down_proj.quant_method,
                           CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)

         assert qkv_proj.scheme.strategy == strategy
+        assert qkv_proj.scheme.is_static_input_scheme
         expected_type = (torch.int8 if quant_type == QuantizationType.INT else
                          torch.float8_e4m3fn)

@@ -79,7 +79,8 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
         qkv_proj = layer.self_attn.qkv_proj

         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
+        assert not qkv_proj.scheme.is_static_input_scheme
         assert qkv_proj.scheme.strategy == strategy
         assert qkv_proj.weight.dtype is torch.int8
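
For readers skimming the test changes: both tests now assert on the same `CompressedTensorsW8A8` class and differ only in `is_static_input_scheme`. Below is a minimal pure-PyTorch sketch, independent of vLLM's kernels, of what that flag distinguishes; the tensor values and the `0.05` scale are made-up illustrations, not values from the repository.

```python
import torch

x = torch.randn(4, 8)  # a toy batch of 4 token activations

# Static input scheme: one per-tensor scale, calibrated offline and loaded
# from the checkpoint (0.05 is a made-up stand-in for a calibrated value).
static_scale = torch.tensor(0.05)
x_q_static = torch.clamp((x / static_scale).round(), -128, 127).to(torch.int8)

# Dynamic input scheme: a per-token scale computed from the activations at runtime.
dynamic_scale = x.abs().amax(dim=1, keepdim=True) / 127.0
x_q_dynamic = torch.clamp((x / dynamic_scale).round(), -128, 127).to(torch.int8)

assert x_q_static.dtype == x_q_dynamic.dtype == torch.int8
assert dynamic_scale.shape == (4, 1)  # one scale per token, vs. one scalar overall
```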

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 5 additions & 6 deletions
@@ -9,8 +9,7 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
     CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
-    CompressedTensorsWNA16)
+    CompressedTensorsW8A8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     find_first_name_or_class_match)

@@ -150,12 +149,12 @@ def _get_schema(self, weight_quant: BaseModel,

         if self.quant_format == CompressionFormat.int_quantized.value:
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8StaticTensor(
-                    strategy=weight_quant.strategy)
+                return CompressedTensorsW8A8(strategy=weight_quant.strategy,
+                                             is_static_input_scheme=True)

             if self._is_dynamic_token_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8DynamicToken(
-                    strategy=weight_quant.strategy)
+                return CompressedTensorsW8A8(strategy=weight_quant.strategy,
+                                             is_static_input_scheme=False)

         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")

vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py

Lines changed: 1 addition & 4 deletions
@@ -3,9 +3,6 @@
     CompressedTensorsUnquantized)
 from .compressed_tensors_w4a16_24 import (  # noqa: F401
     W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
-from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
-    CompressedTensorsW8A8DynamicToken)
-from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
-    CompressedTensorsW8A8StaticTensor)
+from .compressed_tensors_w8a8 import CompressedTensorsW8A8  # noqa: F401
 from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS  # noqa: F401
 from .compressed_tensors_wNa16 import CompressedTensorsWNA16  # noqa: F401
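
With this re-export, callers pull the single consolidated scheme from the schemes package, as the updated compressed_tensors.py above now does:

```python
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW8A8)
```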

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py

Lines changed: 33 additions & 1 deletion
@@ -3,6 +3,7 @@
 import torch
 from torch.nn import Parameter

+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (

@@ -12,8 +13,9 @@

 class CompressedTensorsW8A8(CompressedTensorsScheme):

-    def __init__(self, strategy: str):
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme

     # Cutlass kernels support only per-tensor and per-channel cases.
     # So if we have a fused module (QKV, MLP) with per tensor scales (thus N

@@ -36,6 +38,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.weight_scale = Parameter(weight_scale_channel,
                                            requires_grad=False)

+        # transpose weights for cutlass.
+        weight = layer.weight
+        layer.weight = Parameter(weight.t(), requires_grad=False)
+
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,

@@ -75,3 +81,29 @@ def create_weights(self, layer: torch.nn.Module,
             "output_dim": 0,
             "weight_loader": weight_loader,
         })
+
+        # INPUT SCALE
+        # Static quantization: load from disk.
+        if self.is_static_input_scheme:
+            input_scale = Parameter(torch.empty(1, dtype=torch.float32),
+                                    requires_grad=False)
+            layer.register_parameter("input_scale", input_scale)
+            set_weight_attrs(input_scale, {
+                "weight_loader": weight_loader,
+                "ignore_warning": True,
+            })
+        # Dynamic quantization: set to None.
+        else:
+            layer.input_scale = None
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        # ops.scaled_int8_quant supports both dynamic and static quant.
+        # * dynamic, layer.input_scale is None and x_scale computed from x.
+        # * static, layer.input_scale is scalar and x_scale is input_scale.
+        x_q, x_scale = ops.scaled_int8_quant(x, layer.input_scale)
+
+        return ops.cutlass_scaled_mm(x_q,
+                                     layer.weight,
+                                     scale_a=x_scale,
+                                     scale_b=layer.weight_scale,
+                                     out_dtype=x.dtype)
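
For intuition about the new `apply_weights` path, here is a rough pure-PyTorch reference that mimics the semantics of `ops.scaled_int8_quant` followed by `ops.cutlass_scaled_mm` using plain tensor ops. It is a CPU sketch under simplifying assumptions (symmetric int8 quantization, and a float matmul standing in for the int8 GEMM with int32 accumulation), not the CUTLASS kernel; all names and shapes below are illustrative.

```python
from typing import Optional

import torch


def scaled_int8_quant_ref(x: torch.Tensor,
                          scale: Optional[torch.Tensor] = None):
    """Static: use the provided scale. Dynamic: compute a per-token scale from x."""
    if scale is None:
        scale = x.abs().amax(dim=-1, keepdim=True) / 127.0
    x_q = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)
    return x_q, scale


def w8a8_linear_ref(x: torch.Tensor,
                    weight_q: torch.Tensor,      # int8, [in_features, out_features]
                    weight_scale: torch.Tensor,  # per-tensor or per-output-channel
                    input_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    x_q, x_scale = scaled_int8_quant_ref(x, input_scale)
    # The real path runs an int8 GEMM with int32 accumulation; a float matmul
    # stands in here so the sketch runs anywhere.
    acc = x_q.to(torch.float32) @ weight_q.to(torch.float32)
    # Dequantize: activation scale broadcasts over rows, weight scale over columns.
    return (acc * x_scale * weight_scale).to(x.dtype)


# Dynamic per-token activations with per-output-channel weight scales:
x = torch.randn(4, 16)
w_q = torch.randint(-128, 128, (16, 8), dtype=torch.int8)
w_scale = torch.rand(8) * 0.01
y = w8a8_linear_ref(x, w_q, w_scale)
assert y.shape == (4, 8) and y.dtype == x.dtype
```

The `[in_features, out_features]` weight layout in the sketch mirrors the "transpose weights for cutlass" step added to `process_weights_after_loading` above.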

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py

Lines changed: 0 additions & 47 deletions
This file was deleted.
