@@ -122,7 +122,7 @@ def process_file_microscale_scheme(
122122
123123 :param file_path: safetensors file to process
124124 :param save_path: save path of file with quantized weights
125- :param scheme: quantization scheme to apply to tensors
125+ :param scheme: microscale quantization scheme (NVFP4, MXFP4)
126126 :param ignore: modules to ignore. Modules ending with "norm" are automatically
127127 ignored
128128 :param device: device used to quantize and compress weights
@@ -138,61 +138,7 @@ def process_file_microscale_scheme(
138138 fused_sets , unmatched_sets = get_fused_names (tensors )
139139 assert len (unmatched_sets ) <= 0 # should be caught by validate_safetensors_index
140140
141- fused_name_to_fused_index : dict [str , int ] # fused_name -> fused_index
142- fused_modules : dict [int , dict [str , Module ]] # fused_index -> named_modules
143-
144- fused_name_to_fused_index = {
145- name : index
146- for index , matched_set in enumerate (fused_sets )
147- for name in matched_set .values ()
148- }
149- fused_modules = defaultdict (dict )
150-
151- for module_name , name in match_quantizable_tensors (tensors , ignore , scheme .targets ):
152- validate_weight_for_quantization (tensors [name ], scheme , name )
153-
154- # 1. initialize module with qparams (on device)
155- module = initialize_quantized_linear (tensors [name ], scheme , device )
156-
157- # 2. calibrate weight qparams. Delay scale/zp calibration for fused modules
158- calibrate_global_scale (module )
159- if name in fused_name_to_fused_index :
160- fused_index = fused_name_to_fused_index [name ]
161- fused_modules [fused_index ][name ] = module
162- continue
163-
164- calibrate_scale_zp (module )
165-
166- # 3. compress module using qparams
167- compress_module (module )
168-
169- # 4. save compressed data (on cpu)
170- del tensors [name ]
171- prefix = module_name + "."
172- for key , value in module .state_dict (prefix = prefix ).items ():
173- tensors [key ] = value .to ("cpu" )
174-
175- # compress and save microscale fused modules
176- for named_modules in fused_modules .values ():
177- # 2.1. fuse global scales
178- global_scales = [m .weight_global_scale for m in named_modules .values ()]
179- fused_global_scale = torch .min (torch .cat (global_scales , dim = 0 ))
180-
181- for name , module in named_modules .items ():
182- module_name , _ = name .rsplit ("." , 1 )
183- module .weight_global_scale .data .copy_ (fused_global_scale )
184-
185- # 2.2. finish calibration with fused global scales
186- calibrate_scale_zp (module )
187-
188- # 3. compress module using microscale qparams
189- compress_module (module )
190-
191- # 4. save compressed data (on cpu)
192- del tensors [name ]
193- prefix = module_name + "."
194- for key , value in module .state_dict (prefix = prefix ).items ():
195- tensors [key ] = value .to ("cpu" )
141+ tensors , _ = _process_tensors_microscale (tensors , scheme , ignore , device )
196142
197143 save_file (tensors , save_path )
198144 total_size = sum (tensor .nbytes for tensor in tensors .values ())
@@ -231,9 +177,9 @@ def process_file_group_microscale_scheme(
231177 "Use `process_file` or `process_file_microscale_scheme` for "
232178 "non-microscale schemes"
233179 )
234- assert len (file_paths ) == len (
235- save_paths
236- ), "file_paths and save_paths must have the same length"
180+ assert len (file_paths ) == len (save_paths ), (
181+ "file_paths and save_paths must have the same length"
182+ )
237183
238184 # Load all tensors from the group, tracking which output shard each belongs to
239185 tensor_to_shard : dict [str , str ] = {}
@@ -254,6 +200,54 @@ def process_file_group_microscale_scheme(
254200 "This is a bug in group_files_by_fused_weights."
255201 )
256202
203+ tensors , tensor_to_shard = _process_tensors_microscale (
204+ tensors , scheme , ignore , device , tensor_to_shard
205+ )
206+
207+ # Re-shard: write each tensor back to its original output file
208+ output_shards : dict [str , dict [str , torch .Tensor ]] = defaultdict (dict )
209+ for name , tensor in tensors .items ():
210+ output_shards [tensor_to_shard [name ]][name ] = tensor
211+
212+ total_size = 0
213+ weight_map : dict [str , str ] = {}
214+ for save_path in save_paths :
215+ shard_name = os .path .basename (save_path )
216+ shard_tensors = output_shards .get (shard_name , {})
217+ os .makedirs (os .path .dirname (os .path .abspath (save_path )), exist_ok = True )
218+ save_file (shard_tensors , save_path )
219+ total_size += sum (t .nbytes for t in shard_tensors .values ())
220+ weight_map .update ({k : shard_name for k in shard_tensors })
221+
222+ return total_size , weight_map
223+
224+
225+ def _process_tensors_microscale (
226+ tensors : dict [str , torch .Tensor ],
227+ scheme : QuantizationScheme ,
228+ ignore : Iterable [str ],
229+ device : str | torch .device ,
230+ tensor_to_shard : dict [str , str ] | None = None ,
231+ ) -> tuple [dict [str , torch .Tensor ], dict [str , str ] | None ]:
232+ """
233+ Core microscale quantization logic shared by process_file_microscale_scheme
234+ and process_file_group_microscale_scheme.
235+
236+ Processes all quantizable tensors in the given dict in-place, handling
237+ global scale fusion for fused weight sets (q/k/v, gate/up). When
238+ tensor_to_shard is provided, shard assignments are updated to follow
239+ compressed tensor keys.
240+
241+ :param tensors: dict of tensor name -> tensor, modified in-place
242+ :param scheme: microscale quantization scheme (NVFP4, MXFP4)
243+ :param ignore: modules to ignore
244+ :param device: device used to quantize and compress weights
245+ :param tensor_to_shard: optional mapping of tensor name -> shard filename,
246+ updated in-place when compressed tensors produce new keys
247+ :return: (tensors, tensor_to_shard) tuple with updated contents
248+ """
249+ fused_sets , _ = get_fused_names (list (tensors .keys ()))
250+
257251 fused_name_to_fused_index : dict [str , int ] = {
258252 name : index
259253 for index , matched_set in enumerate (fused_sets )
@@ -280,13 +274,14 @@ def process_file_group_microscale_scheme(
280274 # 3. compress module using qparams
281275 compress_module (module )
282276
283- # 4. save compressed data back to cpu, preserving shard assignment
284- original_shard = tensor_to_shard [name ]
277+ # 4. save compressed data back to cpu
278+ original_shard = tensor_to_shard [name ] if tensor_to_shard else None
285279 del tensors [name ]
286280 prefix = module_name + "."
287281 for key , value in module .state_dict (prefix = prefix ).items ():
288282 tensors [key ] = value .to ("cpu" )
289- tensor_to_shard [key ] = original_shard
283+ if tensor_to_shard is not None :
284+ tensor_to_shard [key ] = original_shard
290285
291286 # compress and save microscale fused modules (with fused global scales)
292287 for named_modules in fused_modules .values ():
@@ -304,27 +299,13 @@ def process_file_group_microscale_scheme(
304299 # 3. compress module using microscale qparams
305300 compress_module (module )
306301
307- # 4. save compressed data back to cpu, preserving shard assignment
308- original_shard = tensor_to_shard [name ]
302+ # 4. save compressed data back to cpu
303+ original_shard = tensor_to_shard [name ] if tensor_to_shard else None
309304 del tensors [name ]
310305 prefix = module_name + "."
311306 for key , value in module .state_dict (prefix = prefix ).items ():
312307 tensors [key ] = value .to ("cpu" )
313- tensor_to_shard [key ] = original_shard
314-
315- # Re-shard: write each tensor back to its original output file
316- output_shards : dict [str , dict [str , torch .Tensor ]] = defaultdict (dict )
317- for name , tensor in tensors .items ():
318- output_shards [tensor_to_shard [name ]][name ] = tensor
319-
320- total_size = 0
321- weight_map : dict [str , str ] = {}
322- for save_path in save_paths :
323- shard_name = os .path .basename (save_path )
324- shard_tensors = output_shards .get (shard_name , {})
325- os .makedirs (os .path .dirname (os .path .abspath (save_path )), exist_ok = True )
326- save_file (shard_tensors , save_path )
327- total_size += sum (t .nbytes for t in shard_tensors .values ())
328- weight_map .update ({k : shard_name for k in shard_tensors })
308+ if tensor_to_shard is not None :
309+ tensor_to_shard [key ] = original_shard
329310
330- return total_size , weight_map
311+ return tensors , tensor_to_shard