Skip to content

Commit cf9843b

Browse files
committed
tmp fix
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent f950fb7 commit cf9843b

File tree

2 files changed

+71
-45
lines changed

2 files changed

+71
-45
lines changed

examples/quantization_w4a4_fp4/llama3_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def tokenize(sample):
8585
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
8686
model.device
8787
)
88-
output = model.generate(input_ids, max_new_tokens=100)
88+
output = model.generate(input_ids, max_new_tokens=10)
8989

9090
print(tokenizer.decode(output[0]))
9191
print("==========================================\n\n")

src/llmcompressor/transformers/compression/quantization_format.py

Lines changed: 70 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def _get_quant_compression_format(
2424
is_weight_only = weight_args is not None and input_args is None
2525

2626
if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
27+
if weight_args.is_mx:
28+
return CompressionFormat.mxfp4_pack_quantized
2729
return CompressionFormat.nvfp4_pack_quantized
2830

2931
if is_weight_only: # w4a16 and w8a16
@@ -55,6 +57,30 @@ def _get_quant_compression_format(
5557
return CompressionFormat.naive_quantized
5658

5759

60+
def _get_unique_quant_args(model):
61+
"""
62+
Gets a list of all the unique quantization settings present in model
63+
"""
64+
from compressed_tensors.quantization.utils import (
65+
is_model_quantized,
66+
is_module_quantized,
67+
iter_named_leaf_modules,
68+
)
69+
quant_info_weight = []
70+
quant_info_inputs = []
71+
for _, submodule in iter_named_leaf_modules(model):
72+
if is_module_quantized(submodule):
73+
weight_scheme = submodule.quantization_scheme.weights
74+
input_scheme = submodule.quantization_scheme.input_activations
75+
if weight_scheme is not None:
76+
if weight_scheme not in quant_info_weight:
77+
quant_info_weight.append(weight_scheme)
78+
if input_scheme is not None:
79+
if input_scheme not in quant_info_inputs:
80+
quant_info_inputs.append(input_scheme)
81+
82+
return quant_info_weight, quant_info_inputs
83+
5884
def infer_and_set_per_module_quantization_format(
5985
model,
6086
quantization_format: Optional[str] = None,
@@ -79,50 +105,50 @@ def infer_and_set_per_module_quantization_format(
79105

80106
if not save_compressed:
81107
return None
82-
if save_compressed:
83-
weight_args, input_args = _get_unique_quant_args(model)
84-
is_24_structure = (
85-
SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
86-
)
87-
is_weight_only = len(input_args) == 0 and len(weight_args) > 0
88-
if (
89-
weight_args[0].num_bits == 4
90-
and weight_args[0].type == QuantizationType.FLOAT.value
91-
):
92-
if weight_args[0].is_mx:
93-
return CompressionFormat.mxfp4_pack_quantized
94-
else:
95-
return CompressionFormat.nvfp4_pack_quantized
96-
97-
if is_weight_only: # w4a16 and w8a16
98-
is_valid_pack = all(
99-
weight_arg.num_bits in [4, 8]
100-
and weight_arg.type == QuantizationType.INT.value
101-
for weight_arg in weight_args
102-
)
103-
if not is_valid_pack: # packing only valid for int4 and int 8
104-
return CompressionFormat.naive_quantized
105-
if is_24_structure:
106-
for arg in weight_args:
107-
if (
108-
arg.strategy is not QuantizationStrategy.CHANNEL.value
109-
and arg.strategy is not QuantizationStrategy.GROUP.value
110-
):
111-
# marlin24 kernel only applicable for channel/group quantization
112-
return CompressionFormat.pack_quantized
113-
return CompressionFormat.marlin_24
114-
return CompressionFormat.pack_quantized
115-
else: # w8a8 float and int
116-
if len(weight_args) == 1:
117-
if (
118-
weight_args[0].type == QuantizationType.FLOAT.value
119-
and weight_args[0].num_bits == 8
120-
):
121-
return CompressionFormat.float_quantized
122-
if weight_args[0].type == QuantizationType.INT.value:
123-
return CompressionFormat.int_quantized
124-
125-
return CompressionFormat.naive_quantized
108+
# if save_compressed:
109+
# weight_args, input_args = _get_unique_quant_args(model)
110+
# is_24_structure = (
111+
# SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
112+
# )
113+
# is_weight_only = len(input_args) == 0 and len(weight_args) > 0
114+
# if (
115+
# weight_args[0].num_bits == 4
116+
# and weight_args[0].type == QuantizationType.FLOAT.value
117+
# ):
118+
# if weight_args[0].is_mx:
119+
# return CompressionFormat.mxfp4_pack_quantized
120+
# else:
121+
# return CompressionFormat.nvfp4_pack_quantized
122+
123+
# if is_weight_only: # w4a16 and w8a16
124+
# is_valid_pack = all(
125+
# weight_arg.num_bits in [4, 8]
126+
# and weight_arg.type == QuantizationType.INT.value
127+
# for weight_arg in weight_args
128+
# )
129+
# if not is_valid_pack: # packing only valid for int4 and int 8
130+
# return CompressionFormat.naive_quantized
131+
# if is_24_structure:
132+
# for arg in weight_args:
133+
# if (
134+
# arg.strategy is not QuantizationStrategy.CHANNEL.value
135+
# and arg.strategy is not QuantizationStrategy.GROUP.value
136+
# ):
137+
# # marlin24 kernel only applicable for channel/group quantization
138+
# return CompressionFormat.pack_quantized
139+
# return CompressionFormat.marlin_24
140+
# return CompressionFormat.pack_quantized
141+
# else: # w8a8 float and int
142+
# if len(weight_args) == 1:
143+
# if (
144+
# weight_args[0].type == QuantizationType.FLOAT.value
145+
# and weight_args[0].num_bits == 8
146+
# ):
147+
# return CompressionFormat.float_quantized
148+
# if weight_args[0].type == QuantizationType.INT.value:
149+
# return CompressionFormat.int_quantized
150+
151+
# return CompressionFormat.naive_quantized
126152

127153

128154
if quantization_format:

0 commit comments

Comments (0)