1
- from typing import Optional
1
+ from typing import List , Optional
2
2
3
3
from compressed_tensors import CompressionFormat
4
4
from compressed_tensors .config import SparsityStructure
5
- from compressed_tensors .quantization import QuantizationStrategy , QuantizationType
5
+ from compressed_tensors .quantization import (
6
+ QuantizationArgs ,
7
+ QuantizationStrategy ,
8
+ QuantizationType ,
9
+ )
6
10
from compressed_tensors .quantization .utils import is_module_quantized
11
+ from loguru import logger
7
12
8
- __all__ = ["infer_quantization_format " ]
13
+ __all__ = ["infer_and_set_per_module_quantization_format " ]
9
14
10
15
11
def _get_quant_compression_format(
    input_args: QuantizationArgs,
    weight_args: QuantizationArgs,
    sparsity_structure: Optional[str] = None,
) -> CompressionFormat:
    """
    Infer the compression format for a single module from its quantization args.

    :param input_args: quantization args for the module's input activations,
        or None when activations are not quantized
    :param weight_args: quantization args for the module's weights; expected
        to be non-None (callers skip modules without weight quantization)
    :param sparsity_structure: optional sparsity structure string, used to
        detect 2:4 sparsity for the marlin24 kernel
    :return: compression format appropriate for this module
    """
    is_24_structure = (
        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
    )
    is_weight_only = weight_args is not None and input_args is None

    # FP4 weights always use the nvfp4 packed format, regardless of activations
    if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
        return CompressionFormat.nvfp4_pack_quantized

    if is_weight_only:  # w4a16 and w8a16
        # packing is only valid for int4 and int8
        is_valid_pack = (
            weight_args.num_bits in [4, 8]
            and weight_args.type == QuantizationType.INT.value
        )
        if not is_valid_pack:
            return CompressionFormat.naive_quantized
        if is_24_structure:
            # Compare by value, not identity: `is not` on enum `.value` strings
            # relies on interning and breaks if `strategy` holds an enum member.
            if weight_args.strategy not in (
                QuantizationStrategy.CHANNEL.value,
                QuantizationStrategy.GROUP.value,
            ):
                # marlin24 kernel only applicable for channel/group quantization
                return CompressionFormat.pack_quantized
            return CompressionFormat.marlin_24
        return CompressionFormat.pack_quantized

    else:  # w8a8 float and int
        if (
            weight_args.type == QuantizationType.FLOAT.value
            and weight_args.num_bits == 8
        ):
            return CompressionFormat.float_quantized
        if weight_args.type == QuantizationType.INT.value:
            return CompressionFormat.int_quantized

        return CompressionFormat.naive_quantized
58
def infer_and_set_per_module_quantization_format(
    model,
    quantization_format: Optional[str] = None,
    save_compressed: bool = False,
    sparsity_structure: Optional[str] = None,
) -> Optional[List[str]]:
    """
    Infers the quantization format for a model based on its state and provided
    compression arguments. Also updates the quantization_scheme.format value
    based on the inferred format. Returns the unique list of formats in the
    model, or None if no format applies.

    For a summary of the formats, see `docs/guides/compression_formats.md`.

    :param model: model whose quantized submodules are inspected
    :param quantization_format: explicit format override; when provided it is
        returned as-is without inspecting the model
    :param save_compressed: used to infer a quantization format if None is provided
    :param sparsity_structure: optional sparsity structure, forwarded to
        per-module format inference
    :return: compression format(s) appropriate for the model, or None
    """
    if not save_compressed:
        # format will be inferred from config
        return None

    if quantization_format:
        return [quantization_format]

    unique_formats = []
    for submodule in model.modules():
        if not is_module_quantized(submodule):
            continue
        weight_scheme = submodule.quantization_scheme.weights
        input_scheme = submodule.quantization_scheme.input_activations
        if weight_scheme is None:
            continue  # no weight quant - nothing to compress
        compression_format = _get_quant_compression_format(
            input_scheme, weight_scheme, sparsity_structure
        )

        # If a format was already set, check that it matches the inferred one
        if submodule.quantization_scheme.format is not None:
            # If it does not, warn the user
            if submodule.quantization_scheme.format != compression_format.value:
                logger.warning(
                    "The provided format for the module does not match the "
                    "inferred format. Compression may fail "
                )
        else:
            # If not set, record the inferred format on the scheme
            submodule.quantization_scheme.format = compression_format.value

        if submodule.quantization_scheme.format not in unique_formats:
            unique_formats.append(submodule.quantization_scheme.format)

    return unique_formats or None
0 commit comments