Skip to content

Commit 11240f1

Browse files
committed
refactor: extract _process_tensors_microscale to reduce duplication
Shared microscale processing logic now lives in _process_tensors_microscale, called by both process_file_microscale_scheme and process_file_group_microscale_scheme. Signed-off-by: David Zheng <dqzheng1996@gmail.com>
1 parent 651a6d3 commit 11240f1

File tree

3 files changed

+74
-92
lines changed

3 files changed

+74
-92
lines changed

src/llmcompressor/entrypoints/model_free/process.py

Lines changed: 62 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def process_file_microscale_scheme(
122122
123123
:param file_path: safetensors file to process
124124
:param save_path: save path of file with quantized weights
125-
:param scheme: quantization scheme to apply to tensors
125+
:param scheme: microscale quantization scheme (NVFP4, MXFP4)
126126
:param ignore: modules to ignore. Modules ending with "norm" are automatically
127127
ignored
128128
:param device: device used to quantize and compress weights
@@ -138,61 +138,7 @@ def process_file_microscale_scheme(
138138
fused_sets, unmatched_sets = get_fused_names(tensors)
139139
assert len(unmatched_sets) <= 0 # should be caught by validate_safetensors_index
140140

141-
fused_name_to_fused_index: dict[str, int] # fused_name -> fused_index
142-
fused_modules: dict[int, dict[str, Module]] # fused_index -> named_modules
143-
144-
fused_name_to_fused_index = {
145-
name: index
146-
for index, matched_set in enumerate(fused_sets)
147-
for name in matched_set.values()
148-
}
149-
fused_modules = defaultdict(dict)
150-
151-
for module_name, name in match_quantizable_tensors(tensors, ignore, scheme.targets):
152-
validate_weight_for_quantization(tensors[name], scheme, name)
153-
154-
# 1. initialize module with qparams (on device)
155-
module = initialize_quantized_linear(tensors[name], scheme, device)
156-
157-
# 2. calibrate weight qparams. Delay scale/zp calibration for fused modules
158-
calibrate_global_scale(module)
159-
if name in fused_name_to_fused_index:
160-
fused_index = fused_name_to_fused_index[name]
161-
fused_modules[fused_index][name] = module
162-
continue
163-
164-
calibrate_scale_zp(module)
165-
166-
# 3. compress module using qparams
167-
compress_module(module)
168-
169-
# 4. save compressed data (on cpu)
170-
del tensors[name]
171-
prefix = module_name + "."
172-
for key, value in module.state_dict(prefix=prefix).items():
173-
tensors[key] = value.to("cpu")
174-
175-
# compress and save microscale fused modules
176-
for named_modules in fused_modules.values():
177-
# 2.1. fuse global scales
178-
global_scales = [m.weight_global_scale for m in named_modules.values()]
179-
fused_global_scale = torch.min(torch.cat(global_scales, dim=0))
180-
181-
for name, module in named_modules.items():
182-
module_name, _ = name.rsplit(".", 1)
183-
module.weight_global_scale.data.copy_(fused_global_scale)
184-
185-
# 2.2. finish calibration with fused global scales
186-
calibrate_scale_zp(module)
187-
188-
# 3. compress module using microscale qparams
189-
compress_module(module)
190-
191-
# 4. save compressed data (on cpu)
192-
del tensors[name]
193-
prefix = module_name + "."
194-
for key, value in module.state_dict(prefix=prefix).items():
195-
tensors[key] = value.to("cpu")
141+
tensors, _ = _process_tensors_microscale(tensors, scheme, ignore, device)
196142

197143
save_file(tensors, save_path)
198144
total_size = sum(tensor.nbytes for tensor in tensors.values())
@@ -231,9 +177,9 @@ def process_file_group_microscale_scheme(
231177
"Use `process_file` or `process_file_microscale_scheme` for "
232178
"non-microscale schemes"
233179
)
234-
assert len(file_paths) == len(
235-
save_paths
236-
), "file_paths and save_paths must have the same length"
180+
assert len(file_paths) == len(save_paths), (
181+
"file_paths and save_paths must have the same length"
182+
)
237183

238184
# Load all tensors from the group, tracking which output shard each belongs to
239185
tensor_to_shard: dict[str, str] = {}
@@ -254,6 +200,54 @@ def process_file_group_microscale_scheme(
254200
"This is a bug in group_files_by_fused_weights."
255201
)
256202

203+
tensors, tensor_to_shard = _process_tensors_microscale(
204+
tensors, scheme, ignore, device, tensor_to_shard
205+
)
206+
207+
# Re-shard: write each tensor back to its original output file
208+
output_shards: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
209+
for name, tensor in tensors.items():
210+
output_shards[tensor_to_shard[name]][name] = tensor
211+
212+
total_size = 0
213+
weight_map: dict[str, str] = {}
214+
for save_path in save_paths:
215+
shard_name = os.path.basename(save_path)
216+
shard_tensors = output_shards.get(shard_name, {})
217+
os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
218+
save_file(shard_tensors, save_path)
219+
total_size += sum(t.nbytes for t in shard_tensors.values())
220+
weight_map.update({k: shard_name for k in shard_tensors})
221+
222+
return total_size, weight_map
223+
224+
225+
def _process_tensors_microscale(
226+
tensors: dict[str, torch.Tensor],
227+
scheme: QuantizationScheme,
228+
ignore: Iterable[str],
229+
device: str | torch.device,
230+
tensor_to_shard: dict[str, str] | None = None,
231+
) -> tuple[dict[str, torch.Tensor], dict[str, str] | None]:
232+
"""
233+
Core microscale quantization logic shared by process_file_microscale_scheme
234+
and process_file_group_microscale_scheme.
235+
236+
Processes all quantizable tensors in the given dict in-place, handling
237+
global scale fusion for fused weight sets (q/k/v, gate/up). When
238+
tensor_to_shard is provided, shard assignments are updated to follow
239+
compressed tensor keys.
240+
241+
:param tensors: dict of tensor name -> tensor, modified in-place
242+
:param scheme: microscale quantization scheme (NVFP4, MXFP4)
243+
:param ignore: modules to ignore
244+
:param device: device used to quantize and compress weights
245+
:param tensor_to_shard: optional mapping of tensor name -> shard filename,
246+
updated in-place when compressed tensors produce new keys
247+
:return: (tensors, tensor_to_shard) tuple with updated contents
248+
"""
249+
fused_sets, _ = get_fused_names(list(tensors.keys()))
250+
257251
fused_name_to_fused_index: dict[str, int] = {
258252
name: index
259253
for index, matched_set in enumerate(fused_sets)
@@ -280,13 +274,14 @@ def process_file_group_microscale_scheme(
280274
# 3. compress module using qparams
281275
compress_module(module)
282276

283-
# 4. save compressed data back to cpu, preserving shard assignment
284-
original_shard = tensor_to_shard[name]
277+
# 4. save compressed data back to cpu
278+
original_shard = tensor_to_shard[name] if tensor_to_shard else None
285279
del tensors[name]
286280
prefix = module_name + "."
287281
for key, value in module.state_dict(prefix=prefix).items():
288282
tensors[key] = value.to("cpu")
289-
tensor_to_shard[key] = original_shard
283+
if tensor_to_shard is not None:
284+
tensor_to_shard[key] = original_shard
290285

291286
# compress and save microscale fused modules (with fused global scales)
292287
for named_modules in fused_modules.values():
@@ -304,27 +299,13 @@ def process_file_group_microscale_scheme(
304299
# 3. compress module using microscale qparams
305300
compress_module(module)
306301

307-
# 4. save compressed data back to cpu, preserving shard assignment
308-
original_shard = tensor_to_shard[name]
302+
# 4. save compressed data back to cpu
303+
original_shard = tensor_to_shard[name] if tensor_to_shard else None
309304
del tensors[name]
310305
prefix = module_name + "."
311306
for key, value in module.state_dict(prefix=prefix).items():
312307
tensors[key] = value.to("cpu")
313-
tensor_to_shard[key] = original_shard
314-
315-
# Re-shard: write each tensor back to its original output file
316-
output_shards: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
317-
for name, tensor in tensors.items():
318-
output_shards[tensor_to_shard[name]][name] = tensor
319-
320-
total_size = 0
321-
weight_map: dict[str, str] = {}
322-
for save_path in save_paths:
323-
shard_name = os.path.basename(save_path)
324-
shard_tensors = output_shards.get(shard_name, {})
325-
os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
326-
save_file(shard_tensors, save_path)
327-
total_size += sum(t.nbytes for t in shard_tensors.values())
328-
weight_map.update({k: shard_name for k in shard_tensors})
308+
if tensor_to_shard is not None:
309+
tensor_to_shard[key] = original_shard
329310

330-
return total_size, weight_map
311+
return tensors, tensor_to_shard

tests/llmcompressor/entrypoints/model_free/test_reindexing_elimination.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,16 @@ def qkv_tensors(self):
122122
}
123123

124124
def _save_split_shards(self, tmp_path, tensors):
125-
shard1 = {"model.layers.0.self_attn.q_proj.weight":
126-
tensors["model.layers.0.self_attn.q_proj.weight"]}
127-
shard2 = {k: v for k, v in tensors.items()
128-
if k != "model.layers.0.self_attn.q_proj.weight"}
125+
shard1 = {
126+
"model.layers.0.self_attn.q_proj.weight": tensors[
127+
"model.layers.0.self_attn.q_proj.weight"
128+
]
129+
}
130+
shard2 = {
131+
k: v
132+
for k, v in tensors.items()
133+
if k != "model.layers.0.self_attn.q_proj.weight"
134+
}
129135
shard1_path = tmp_path / "shard-00001.safetensors"
130136
shard2_path = tmp_path / "shard-00002.safetensors"
131137
save_file(shard1, shard1_path)
@@ -175,9 +181,7 @@ def test_group_processing_produces_same_keys_as_single_shard(
175181

176182
assert set(weight_map_group.keys()) == set(weight_map_merged.keys())
177183

178-
def test_group_processing_preserves_original_sharding(
179-
self, qkv_tensors, tmp_path
180-
):
184+
def test_group_processing_preserves_original_sharding(self, qkv_tensors, tmp_path):
181185
scheme = _make_nvfp4_scheme()
182186
split_dir = tmp_path / "split"
183187
split_dir.mkdir()
@@ -201,9 +205,7 @@ def test_group_processing_preserves_original_sharding(
201205
assert save_path.exists()
202206
assert save_path.stat().st_size > 0
203207

204-
def test_group_processing_total_size_matches_merged(
205-
self, qkv_tensors, tmp_path
206-
):
208+
def test_group_processing_total_size_matches_merged(self, qkv_tensors, tmp_path):
207209
scheme = _make_nvfp4_scheme()
208210
split_dir = tmp_path / "split"
209211
split_dir.mkdir()

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,4 +709,3 @@ def test_search_observer_invalid_rejected():
709709

710710
with pytest.raises(ValidationError, match="search_observer must be one of"):
711711
AWQModifier(scheme="W4A16_ASYM", search_observer="invalid_observer")
712-

0 commit comments

Comments (0)