1+ import json
12import os
23import shutil
34from pathlib import Path
1415from compressed_tensors .quantization import QuantizationScheme
1516from loguru import logger
1617
17- from llmcompressor .entrypoints .model_free .helpers import gpu_if_available
18+ from llmcompressor .entrypoints .model_free .helpers import (
19+ find_safetensors_index_file ,
20+ gpu_if_available ,
21+ group_files_by_fused_weights ,
22+ )
1823from llmcompressor .entrypoints .model_free .microscale import (
1924 is_microscale_scheme ,
2025)
2126from llmcompressor .entrypoints .model_free .process import (
2227 process_file ,
28+ process_file_group_microscale_scheme ,
2329 process_file_microscale_scheme ,
2430 validate_file ,
2531)
@@ -45,9 +51,14 @@ def model_free_ptq(
4551):
4652 """
4753 Quantize a model without the need for a model definition. This function operates on
48- a model stub or folder containing weights saved in safetensors files
54+ a model stub or folder containing weights saved in safetensors files.
55+
56+ For microscale schemes (NVFP4, MXFP4), fused weight sets (q/k/v, gate/up) are
57+ automatically grouped for joint processing even when split across shards, removing
58+ the need to run reindex_fused_weights as a preprocessing step.
4959
5060 :param model_stub: huggingface model hub or path to local weights files
61+ :param save_directory: directory to save quantized weights to
5162 :param scheme: weight quantization scheme or preset scheme name
5263 :param ignore: modules to ignore. Modules ending with "norm" are automatically
5364 ignored
@@ -64,32 +75,31 @@ def model_free_ptq(
6475 device = gpu_if_available (device )
6576 validate_safetensors_index (model_files , scheme )
6677
67- # 0. collect safetensors files, copy files
68- jobs = []
69- job_fn = (
70- process_file
71- if not is_microscale_scheme (scheme )
72- else process_file_microscale_scheme
73- )
78+ # copy non-safetensors files (configs, tokenizers, etc.)
7479 for file_path , resolved_path in model_files .items ():
75- save_path = Path (save_directory ) / file_path
76-
77- if file_path .endswith ("safetensors" ):
78- jobs .append (
79- (job_fn , resolved_path , save_path , scheme , ignore , device , converter )
80- )
81-
82- else :
80+ if not file_path .endswith ("safetensors" ):
81+ save_path = Path (save_directory ) / file_path
8382 if is_weights_file (file_path ):
8483 logger .warning (f"Skip processing for weights file { file_path } " )
8584 save_path .parent .mkdir (parents = True , exist_ok = True )
86- logger .info (f"Copying { file_path } { save_path } " )
85+ logger .info (f"Copying { file_path } -> { save_path } " )
8786 shutil .copyfile (resolved_path , save_path )
8887
89- # 1. validate quantizable tensors fail fast before long-running quantization
90- exec_jobs (
91- [(validate_file , * job [1 :]) for job in jobs ], max_workers , desc = "Validating"
88+ # build quantization jobs
89+ if is_microscale_scheme (scheme ):
90+ jobs = _build_microscale_jobs (
91+ model_files , save_directory , scheme , ignore , device , converter
92+ )
93+ else :
94+ jobs = _build_standard_jobs (
95+ model_files , save_directory , scheme , ignore , device , converter
96+ )
97+
98+ # 1. validate quantizable tensors — fail fast before long-running quantization
99+ validate_jobs = _make_validate_jobs (
100+ jobs , model_files , scheme , ignore , device , converter
92101 )
102+ exec_jobs (validate_jobs , max_workers , desc = "Validating" )
93103
94104 # 2-5. quantize and compress weights
95105 total_size = 0
@@ -99,6 +109,149 @@ def model_free_ptq(
99109 total_size += _total_size
100110 weight_map .update (_weight_map )
101111
102- # 5 . update config and safetensors index
112+ # 6 . update config and safetensors index
103113 update_config (save_directory , scheme_name , scheme , ignore , converter )
104114 update_safetensors_index (save_directory , total_size , weight_map )
115+
116+
def _build_standard_jobs(
    model_files: dict[str, str],
    save_directory: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: torch.device,
    converter: Converter | None,
) -> list[tuple]:
    """
    Create one `process_file` job per safetensors shard for non-microscale
    schemes.

    :param model_files: mapping of relative file path -> resolved local path
    :param save_directory: directory quantized shards will be written to
    :param scheme: weight quantization scheme
    :param ignore: module name patterns to skip during quantization
    :param device: device on which quantization runs
    :param converter: optional weight-format converter
    :return: list of (fn, *args) job tuples
    """
    save_root = Path(save_directory)
    return [
        (
            process_file,
            resolved_path,
            save_root / rel_path,
            scheme,
            ignore,
            device,
            converter,
        )
        for rel_path, resolved_path in model_files.items()
        if rel_path.endswith("safetensors")
    ]
142+
143+
def _build_microscale_jobs(
    model_files: dict[str, str],
    save_directory: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: torch.device,
    converter: Converter | None,
) -> list[tuple]:
    """
    Build jobs for microscale schemes, grouping files that share fused weight
    sets so that global scale fusion works correctly across shard boundaries.

    For models where all fused weights are already co-located in single shards,
    each group will be a singleton and process_file_microscale_scheme is used.
    For models with cross-shard fused weights, multi-file groups are formed and
    process_file_group_microscale_scheme is used, eliminating the need for
    reindex_fused_weights preprocessing.

    :param model_files: mapping of relative file path -> resolved local path
    :param save_directory: directory quantized shards will be written to
    :param scheme: microscale weight quantization scheme
    :param ignore: module name patterns to skip during quantization
    :param device: device on which quantization runs
    :param converter: optional weight-format converter
    :return: list of (fn, *args) job tuples
    """

    def _single_file_job(rel_path: str) -> tuple:
        # One self-contained shard: all fused weights live in this file,
        # so the standard single-file microscale path applies.
        return (
            process_file_microscale_scheme,
            model_files[rel_path],
            Path(save_directory) / rel_path,
            scheme,
            ignore,
            device,
            converter,
        )

    index_file = find_safetensors_index_file(model_files)

    if index_file is None:
        # Single-file model (no index.json) — every shard is self-contained
        return [
            _single_file_job(rel_path)
            for rel_path in model_files
            if rel_path.endswith("safetensors")
        ]

    # Read weight map to determine cross-shard fused weight groupings
    with open(index_file, "r") as f:
        weight_map: dict[str, str] = json.load(f)["weight_map"]

    file_groups = group_files_by_fused_weights(weight_map)
    jobs = []

    for group in file_groups:
        if len(group) == 1:
            # No cross-shard fused weights — use the standard single-file path
            jobs.append(_single_file_job(group[0]))
        else:
            # Cross-shard fused weights — load group jointly
            logger.info(
                f"Grouping {len(group)} shards for joint microscale processing "
                f"(fused weights span multiple files): {group}"
            )
            jobs.append(
                (
                    process_file_group_microscale_scheme,
                    [model_files[shard] for shard in group],
                    [Path(save_directory) / shard for shard in group],
                    scheme,
                    ignore,
                    device,
                    converter,
                )
            )

    return jobs
228+
229+
def _make_validate_jobs(
    jobs: list[tuple],
    model_files: dict[str, str],
    scheme: QuantizationScheme,
    ignore: Iterable[str],
    device: torch.device,
    converter: Converter | None,
) -> list[tuple]:
    """
    Build validate_file jobs corresponding to the quantization jobs.
    For group jobs, creates one validate_file call per file in the group.

    :param jobs: quantization job tuples produced by the job builders
    :param model_files: mapping of relative file path -> resolved local path
    :param scheme: weight quantization scheme
    :param ignore: module name patterns to skip during validation
    :param device: device on which validation runs
    :param converter: optional weight-format converter
    :return: list of (validate_file, *args) job tuples
    """
    validate_jobs = []
    for fn, sources, targets, *_rest in jobs:
        if fn is process_file_group_microscale_scheme:
            # Group job: sources/targets are parallel lists of shard paths
            pairs = list(zip(sources, targets))
        else:
            # Single-file job: sources/targets are a single path each
            pairs = [(sources, targets)]
        validate_jobs.extend(
            (validate_file, src, dst, scheme, ignore, device, converter)
            for src, dst in pairs
        )
    return validate_jobs
0 commit comments