@@ -24,6 +24,8 @@ def _get_quant_compression_format(
2424 is_weight_only = weight_args is not None and input_args is None
2525
2626     if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
27+         if weight_args.is_mx:
28+             return CompressionFormat.mxfp4_pack_quantized
2729         return CompressionFormat.nvfp4_pack_quantized
2830
2931 if is_weight_only : # w4a16 and w8a16
@@ -55,6 +57,30 @@ def _get_quant_compression_format(
5557     return CompressionFormat.naive_quantized
5658
5759
60+ def _get_unique_quant_args(model):
61+     """
62+     Gets a list of all the unique quantization settings present in model
63+     """
64+     from compressed_tensors.quantization.utils import (
65+         is_model_quantized,
66+         is_module_quantized,
67+         iter_named_leaf_modules,
68+     )
69+     quant_info_weight = []
70+     quant_info_inputs = []
71+     for _, submodule in iter_named_leaf_modules(model):
72+         if is_module_quantized(submodule):
73+             weight_scheme = submodule.quantization_scheme.weights
74+             input_scheme = submodule.quantization_scheme.input_activations
75+             if weight_scheme is not None:
76+                 if weight_scheme not in quant_info_weight:
77+                     quant_info_weight.append(weight_scheme)
78+             if input_scheme is not None:
79+                 if input_scheme not in quant_info_inputs:
80+                     quant_info_inputs.append(input_scheme)
81+
82+     return quant_info_weight, quant_info_inputs
83+
5884 def infer_and_set_per_module_quantization_format(
5985     model,
6086     quantization_format: Optional[str] = None,
@@ -79,50 +105,50 @@ def infer_and_set_per_module_quantization_format(
79105
80106     if not save_compressed:
81107         return None
82-     if save_compressed:
83-         weight_args, input_args = _get_unique_quant_args(model)
84-         is_24_structure = (
85-             SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
86-         )
87-         is_weight_only = len(input_args) == 0 and len(weight_args) > 0
88-         if (
89-             weight_args[0].num_bits == 4
90-             and weight_args[0].type == QuantizationType.FLOAT.value
91-         ):
92-             if weight_args[0].is_mx:
93-                 return CompressionFormat.mxfp4_pack_quantized
94-             else:
95-                 return CompressionFormat.nvfp4_pack_quantized
96-
97-         if is_weight_only:  # w4a16 and w8a16
98-             is_valid_pack = all(
99-                 weight_arg.num_bits in [4, 8]
100-                 and weight_arg.type == QuantizationType.INT.value
101-                 for weight_arg in weight_args
102-             )
103-             if not is_valid_pack:  # packing only valid for int4 and int 8
104-                 return CompressionFormat.naive_quantized
105-             if is_24_structure:
106-                 for arg in weight_args:
107-                     if (
108-                         arg.strategy is not QuantizationStrategy.CHANNEL.value
109-                         and arg.strategy is not QuantizationStrategy.GROUP.value
110-                     ):
111-                         # marlin24 kernel only applicable for channel/group quantization
112-                         return CompressionFormat.pack_quantized
113-                 return CompressionFormat.marlin_24
114-             return CompressionFormat.pack_quantized
115-         else:  # w8a8 float and int
116-             if len(weight_args) == 1:
117-                 if (
118-                     weight_args[0].type == QuantizationType.FLOAT.value
119-                     and weight_args[0].num_bits == 8
120-                 ):
121-                     return CompressionFormat.float_quantized
122-                 if weight_args[0].type == QuantizationType.INT.value:
123-                     return CompressionFormat.int_quantized
124-
125-             return CompressionFormat.naive_quantized
108+ # if save_compressed:
109+ # weight_args, input_args = _get_unique_quant_args(model)
110+ # is_24_structure = (
111+ # SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
112+ # )
113+ # is_weight_only = len(input_args) == 0 and len(weight_args) > 0
114+ # if (
115+ # weight_args[0].num_bits == 4
116+ # and weight_args[0].type == QuantizationType.FLOAT.value
117+ # ):
118+ # if weight_args[0].is_mx:
119+ # return CompressionFormat.mxfp4_pack_quantized
120+ # else:
121+ # return CompressionFormat.nvfp4_pack_quantized
122+
123+ # if is_weight_only: # w4a16 and w8a16
124+ # is_valid_pack = all(
125+ # weight_arg.num_bits in [4, 8]
126+ # and weight_arg.type == QuantizationType.INT.value
127+ # for weight_arg in weight_args
128+ # )
129+ # if not is_valid_pack: # packing only valid for int4 and int 8
130+ # return CompressionFormat.naive_quantized
131+ # if is_24_structure:
132+ # for arg in weight_args:
133+ # if (
134+ # arg.strategy is not QuantizationStrategy.CHANNEL.value
135+ # and arg.strategy is not QuantizationStrategy.GROUP.value
136+ # ):
137+ # # marlin24 kernel only applicable for channel/group quantization
138+ # return CompressionFormat.pack_quantized
139+ # return CompressionFormat.marlin_24
140+ # return CompressionFormat.pack_quantized
141+ # else: # w8a8 float and int
142+ # if len(weight_args) == 1:
143+ # if (
144+ # weight_args[0].type == QuantizationType.FLOAT.value
145+ # and weight_args[0].num_bits == 8
146+ # ):
147+ # return CompressionFormat.float_quantized
148+ # if weight_args[0].type == QuantizationType.INT.value:
149+ # return CompressionFormat.int_quantized
150+
151+ # return CompressionFormat.naive_quantized
126152
127153
128154     if quantization_format:
0 commit comments