Commit ec60dfa

[Quantization] Support mixed-precision compression (#1713)
SUMMARY:
- Requires neuralmagic/compressed-tensors#415
- Updates `infer_quantization_format` to `infer_and_set_per_module_quantization_format`: instead of returning a single global format, a per-module format is inferred and assigned to each module's quantization scheme for use at compression time. All unique compression formats found in the model are returned.
1 parent 8edae24 commit ec60dfa
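
As context, a minimal usage sketch patterned on the updated test in this commit. It assumes a compressed-tensors build that includes the per-scheme `format` field from neuralmagic/compressed-tensors#415; the toy `Sequential` model is illustrative only.

```python
import torch
from compressed_tensors.quantization import preset_name_to_scheme

from llmcompressor.transformers.compression.quantization_format import (
    infer_and_set_per_module_quantization_format,
)

# Toy stand-in for a quantized model: attach a quantization_scheme to each
# Linear layer, the same way the updated unit test does.
dummy_model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
scheme = preset_name_to_scheme("W4A16", ["Linear"])
for module in dummy_model.modules():
    if isinstance(module, torch.nn.Linear):
        module.quantization_scheme = scheme

# Each module's quantization_scheme.format is set in place; the unique list of
# formats across the model is returned (e.g. ["pack-quantized"] for W4A16).
formats = infer_and_set_per_module_quantization_format(dummy_model, save_compressed=True)
print(formats)
```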

File tree

3 files changed: +97 -82 lines changed

src/llmcompressor/transformers/compression/quantization_format.py

Lines changed: 83 additions & 69 deletions
@@ -1,22 +1,71 @@
-from typing import Optional
+from typing import List, Optional
 
 from compressed_tensors import CompressionFormat
 from compressed_tensors.config import SparsityStructure
-from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
 from compressed_tensors.quantization.utils import is_module_quantized
+from loguru import logger
 
-__all__ = ["infer_quantization_format"]
+__all__ = ["infer_and_set_per_module_quantization_format"]
 
 
-def infer_quantization_format(
+def _get_quant_compression_format(
+    input_args: QuantizationArgs,
+    weight_args: QuantizationArgs,
+    sparsity_structure: Optional[str] = None,
+):
+    is_24_structure = (
+        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    )
+    is_weight_only = weight_args is not None and input_args is None
+
+    if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
+        return CompressionFormat.nvfp4_pack_quantized
+
+    if is_weight_only:  # w4a16 and w8a16
+        is_valid_pack = (
+            weight_args.num_bits in [4, 8]
+            and weight_args.type == QuantizationType.INT.value
+        )
+        if not is_valid_pack:  # packing only valid for int4 and int8
+            return CompressionFormat.naive_quantized
+        if is_24_structure:
+            if (
+                weight_args.strategy is not QuantizationStrategy.CHANNEL.value
+                and weight_args.strategy is not QuantizationStrategy.GROUP.value
+            ):
+                # marlin24 kernel only applicable for channel/group quantization
+                return CompressionFormat.pack_quantized
+            return CompressionFormat.marlin_24
+        return CompressionFormat.pack_quantized
+
+    else:  # w8a8 float and int
+        if (
+            weight_args.type == QuantizationType.FLOAT.value
+            and weight_args.num_bits == 8
+        ):
+            return CompressionFormat.float_quantized
+        if weight_args.type == QuantizationType.INT.value:
+            return CompressionFormat.int_quantized
+
+        return CompressionFormat.naive_quantized
+
+
+def infer_and_set_per_module_quantization_format(
     model,
     quantization_format: Optional[str] = None,
     save_compressed: bool = False,
     sparsity_structure: Optional[str] = None,
-) -> str:
+) -> Optional[List[str]]:
     """
     Infers the quantization format for a model based on its state and provided
-    compression arguments.
+    compression arguments. Also updates the quantization_scheme.format value
+    based on the inferred format. Returns the unique list of formats in the model
+    or None if empty list
 
     For a summary of the formats, see `docs/guides/compression_formats.md`.
 
@@ -27,74 +76,39 @@ def infer_quantization_format(
     :param save_compressed: used to infer a quantization format if None is provided
     :return compression format appropriate for model
     """
-    if quantization_format is not None:
-        return quantization_format
 
-    weight_args, input_args = _get_unique_quant_args(model)
-    if len(weight_args) <= 0:
+    if not save_compressed:
         return None
 
-    if save_compressed:
-        is_24_structure = (
-            SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
-        )
-        is_weight_only = len(input_args) == 0 and len(weight_args) > 0
+    if quantization_format:
+        return [quantization_format]
 
-        if (
-            weight_args[0].num_bits == 4
-            and weight_args[0].type == QuantizationType.FLOAT.value
-        ):
-            return CompressionFormat.nvfp4_pack_quantized
-
-        if is_weight_only:  # w4a16 and w8a16
-            is_valid_pack = all(
-                weight_arg.num_bits in [4, 8]
-                and weight_arg.type == QuantizationType.INT.value
-                for weight_arg in weight_args
-            )
-            if not is_valid_pack:  # packing only valid for int4 and int 8
-                return CompressionFormat.naive_quantized
-            if is_24_structure:
-                for arg in weight_args:
-                    if (
-                        arg.strategy is not QuantizationStrategy.CHANNEL.value
-                        and arg.strategy is not QuantizationStrategy.GROUP.value
-                    ):
-                        # marlin24 kernel only applicable for channel/group quantization
-                        return CompressionFormat.pack_quantized
-                return CompressionFormat.marlin_24
-            return CompressionFormat.pack_quantized
-        else:  # w8a8 float and int
-            if len(weight_args) == 1:
-                if (
-                    weight_args[0].type == QuantizationType.FLOAT.value
-                    and weight_args[0].num_bits == 8
-                ):
-                    return CompressionFormat.float_quantized
-                if weight_args[0].type == QuantizationType.INT.value:
-                    return CompressionFormat.int_quantized
-
-            return CompressionFormat.naive_quantized
-    else:
-        # format will be inferred from config
-        return None
-
-
-def _get_unique_quant_args(model):
-    """
-    Gets a list of all the unique quantization settings present in model
-    """
-    quant_info_weight = []
-    quant_info_inputs = []
+    unique_formats = []
     for submodule in model.modules():
         if is_module_quantized(submodule):
             weight_scheme = submodule.quantization_scheme.weights
             input_scheme = submodule.quantization_scheme.input_activations
-            if weight_scheme is not None:
-                if weight_scheme not in quant_info_weight:
-                    quant_info_weight.append(weight_scheme)
-            if input_scheme is not None:
-                if input_scheme not in quant_info_inputs:
-                    quant_info_inputs.append(input_scheme)
-
-    return quant_info_weight, quant_info_inputs
+            if weight_scheme is None:
+                continue  # no weight quant - nothing to compress
+            compression_format = _get_quant_compression_format(
+                input_scheme, weight_scheme, sparsity_structure
+            )
+
+            # If set, we check if it matches our inferred one
+            if submodule.quantization_scheme.format is not None:
+                # If it does not, warn the user
+                if submodule.quantization_scheme.format != compression_format.value:
+                    logger.warning(
+                        "The provided format for the module does not match the "
+                        "inferred format. Compression may fail "
+                    )
+            else:
+                # If not set, we set ours
+                submodule.quantization_scheme.format = compression_format.value
+
+            if submodule.quantization_scheme.format not in unique_formats:
+                unique_formats.append(submodule.quantization_scheme.format)
+
+    if len(unique_formats) > 0:
+        return unique_formats
+    return None
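
Read as a decision table, `_get_quant_compression_format` maps a module's weight/activation args to a format: NVFP4 packing for 4-bit float weights, `pack_quantized`/`marlin_24` for weight-only int, `float_quantized`/`int_quantized` for w8a8, and `naive_quantized` otherwise. Below is a hedged sketch of exercising the helper directly; the `QuantizationArgs` constructor arguments are assumptions about compressed-tensors' API rather than part of this commit.

```python
from compressed_tensors.quantization import QuantizationArgs

from llmcompressor.transformers.compression.quantization_format import (
    _get_quant_compression_format,
)

# Weight-only int4 (w4a16): int packing applies -> CompressionFormat.pack_quantized
w4a16 = QuantizationArgs(num_bits=4, type="int", strategy="group", group_size=128)
print(_get_quant_compression_format(None, w4a16))

# fp8 weights and activations (w8a8 float) -> CompressionFormat.float_quantized
fp8 = QuantizationArgs(num_bits=8, type="float", strategy="tensor")
print(_get_quant_compression_format(fp8, fp8))
```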

src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py

Lines changed: 11 additions & 10 deletions
@@ -1,12 +1,11 @@
 import os
 import weakref
 from functools import wraps
-from typing import Optional
+from typing import List, Optional
 
 import torch
 from accelerate.accelerator import get_state_dict_offloaded_model
 from compressed_tensors import (
-    CompressionFormat,
     ModelCompressor,
     SparsityCompressionConfig,
     delete_offload_parameter,
@@ -19,7 +18,7 @@
 from llmcompressor.core import active_session
 from llmcompressor.pytorch.model_load.helpers import copy_python_files_from_model_cache
 from llmcompressor.transformers.compression.quantization_format import (
-    infer_quantization_format,
+    infer_and_set_per_module_quantization_format,
 )
 from llmcompressor.transformers.compression.sparsity_metadata_config import (
     SparsityConfigMetadata,
@@ -228,13 +227,15 @@ def get_model_compressor(
             SparsityConfigMetadata.infer_sparsity_structure(model)
         )
 
-    quantization_format: Optional[CompressionFormat] = infer_quantization_format(
-        model=model,
-        quantization_format=quantization_format,
-        save_compressed=save_compressed,
-        sparsity_structure=None
-        if sparsity_config is None
-        else sparsity_config.sparsity_structure,
+    quantization_format: Optional[List[str]] = (
+        infer_and_set_per_module_quantization_format(
+            model=model,
+            quantization_format=quantization_format,
+            save_compressed=save_compressed,
+            sparsity_structure=None
+            if sparsity_config is None
+            else sparsity_config.sparsity_structure,
+        )
     )
 
     return ModelCompressor.from_pretrained_model(

tests/llmcompressor/transformers/compression/test_infer_quant_format.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
 from compressed_tensors.quantization import preset_name_to_scheme
 
 from llmcompressor.transformers.compression.quantization_format import (
-    infer_quantization_format,
+    infer_and_set_per_module_quantization_format,
 )
 from tests.llmcompressor.pytorch.helpers import LinearNet
 
@@ -25,7 +25,7 @@ def test_infer_quant_format(preset, sparsity_structure, expected_format):
     for _, module in dummy_model.named_modules():
         module.quantization_scheme = quant_scheme
 
-    inferred_format = infer_quantization_format(
+    inferred_format = infer_and_set_per_module_quantization_format(
         dummy_model, save_compressed=True, sparsity_structure=sparsity_structure
     )
-    assert inferred_format.value == expected_format
+    assert inferred_format[0] == expected_format
