Skip to content

Commit 8b01f3f

Browse files
ZewenShen-Cohere, gemini-code-assist[bot], and dsikka
authored
[AWQ] Fix _accumulate_mean bug, move AWQ activation averaging off CPU, and improve logging (vllm-project#2161)
This PR addresses the following issues:
1. `_accumulate_mean` produces incorrect output on its first run.
2. `cache_smooth_activations_hook` previously performed the averaging computation on the CPU. When both the hidden dimension and the sequence length are large, this makes AWQ calibration CPU-bound; the slowdown is especially severe when multiple AWQ quantization jobs run concurrently.
3. Added more informative logging to the AWQ calibration grid search, including per-mapping JSON logs.

This PR is a subset of vllm-project#2158.

---------

Signed-off-by: ZewenShen-Cohere <zewen.shen@cohere.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
1 parent 2c43060 commit 8b01f3f

File tree

1 file changed

+71
-8
lines changed
  • src/llmcompressor/modifiers/awq

1 file changed

+71
-8
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ class AWQModifier(Modifier, QuantizationMixin):
160160
_smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr(
161161
default_factory=dict
162162
)
163+
# List to store error metrics for each layer
164+
_error_metrics: list[dict] = PrivateAttr(default_factory=list)
163165

164166
def on_initialize(self, state: State, **kwargs) -> bool:
165167
"""
@@ -273,9 +275,12 @@ def on_finalize(self, state: State, **kwargs) -> bool:
273275
if not self.ended_:
274276
self.on_end(state, None)
275277

278+
self._log_error_metrics()
279+
276280
self._parent_args_cache.clear()
277281
self._smooth_activation_means.clear()
278282
self._resolved_mappings.clear()
283+
self._error_metrics.clear()
279284

280285
return True
281286

@@ -386,11 +391,11 @@ def cache_smooth_activations_hook(
386391
args: tuple[torch.Tensor, ...],
387392
_output: torch.Tensor,
388393
):
389-
self._smooth_activation_means[smooth_name] = _accumulate_mean(
390-
# Assume that first argument is the input
391-
args[0].cpu().abs().detach().flatten(0, -2),
394+
act_mean, count = _accumulate_mean(
395+
args[0].abs().detach().flatten(0, -2),
392396
self._smooth_activation_means.get(smooth_name, None),
393397
)
398+
self._smooth_activation_means[smooth_name] = (act_mean.cpu(), count)
394399

395400
return cache_smooth_activations_hook
396401

@@ -555,6 +560,7 @@ def _compute_best_scale(
555560
best_ratio = -1
556561
best_scales = None
557562
best_error = float("inf")
563+
initial_error = None
558564

559565
org_sd = {
560566
k: v.cpu()
@@ -600,11 +606,13 @@ def _compute_best_scale(
600606
],
601607
):
602608
total_iterations = n_grid * len(duo_scalings)
603-
for grid_idx, use_duo_scaling in tqdm(
609+
pbar = tqdm(
604610
product(range(n_grid), duo_scalings),
605611
total=total_iterations,
606-
desc="Grid search",
607-
):
612+
desc=f"Grid search for {mapping.smooth_name}",
613+
leave=False,
614+
)
615+
for grid_idx, use_duo_scaling in pbar:
608616
# create new scales
609617
ratio = grid_idx / n_grid
610618

@@ -668,13 +676,17 @@ def _compute_best_scale(
668676
# compute mean squared error (L2 norm)
669677
loss = self._compute_loss(fp16_outputs, int_w_outputs)
670678

679+
if initial_error is None:
680+
initial_error = loss
681+
671682
history.append(
672683
{"ratio": ratio, "duo_scaling": use_duo_scaling, "error": loss}
673684
)
674685
if loss < best_error:
675686
best_error = loss
676687
best_ratio = ratio
677688
best_scales = scales.clone()
689+
pbar.set_postfix({"best_error": f"{best_error:.3e}"})
678690

679691
mapping.parent.load_state_dict(org_sd, strict=False)
680692

@@ -687,6 +699,25 @@ def _compute_best_scale(
687699
"https://github.com/vllm-project/llm-compressor/issues"
688700
)
689701

702+
err_reduction = best_error / initial_error if initial_error > 0 else 1.0
703+
logger.debug(
704+
f"AWQ grid search for {mapping.smooth_name}: "
705+
f"initial error = {initial_error:.3e}, "
706+
f"best error = {best_error:.3e}, "
707+
f"error reduction rate (best/initial) = {err_reduction * 100:.3f}%"
708+
)
709+
710+
# Store error metrics for this layer
711+
self._error_metrics.append(
712+
{
713+
"layer_name": mapping.smooth_name,
714+
"parent_name": mapping.parent_name,
715+
"initial_error": initial_error,
716+
"best_error": best_error,
717+
"reduction": err_reduction,
718+
}
719+
)
720+
690721
assert (
691722
torch.isnan(best_scales).sum() == 0
692723
), f"Nan found in scales: {best_scales}"
@@ -705,7 +736,7 @@ def _compute_loss(
705736
# Compute the MSE loss for each batch
706737
for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
707738
loss += torch.nn.functional.mse_loss(
708-
fp16_batch, int_w_batch.to(fp16_batch.device)
739+
fp16_batch, int_w_batch.to(fp16_batch.device), reduction="sum"
709740
).item()
710741
num_elements += fp16_batch.numel()
711742

@@ -714,6 +745,37 @@ def _compute_loss(
714745

715746
return loss
716747

748+
def _log_error_metrics(self):
749+
"""
750+
Log the error metrics (initial error, best error, reduction).
751+
"""
752+
753+
# Prepare data for saving
754+
metrics_data = {
755+
"quantization_config": {
756+
"duo_scaling": self.duo_scaling,
757+
"n_grid": self.n_grid,
758+
},
759+
"total_layers": len(self._error_metrics),
760+
"metrics": self._error_metrics,
761+
}
762+
763+
# Save to disk
764+
logger.debug(f"AWQ per-mapping error metrics: {metrics_data}")
765+
766+
# Also print summary statistics
767+
reductions = [m["reduction"] for m in self._error_metrics]
768+
avg_reduction = sum(reductions) / len(reductions)
769+
min_reduction = min(reductions)
770+
max_reduction = max(reductions)
771+
sorted_reductions = sorted(reductions)
772+
median_reduction = sorted_reductions[len(sorted_reductions) // 2]
773+
logger.debug(
774+
f"Error reduction statistics: "
775+
f"avg={avg_reduction:.4f}, median={median_reduction:.4f}, "
776+
f"min={min_reduction:.4f}, max={max_reduction:.4f}"
777+
)
778+
717779
def _assert_all_activations_consumed(self):
718780
"""
719781
Confirm all activations have been consumed
@@ -860,9 +922,10 @@ def _accumulate_mean(
860922
sum_added = inp.sum(dim=0)
861923
num_added = inp.size(0)
862924
if prev_mean_and_count is None:
863-
return sum_added, num_added
925+
return sum_added / num_added, num_added
864926

865927
prev_mean, prev_count = prev_mean_and_count
928+
prev_mean = prev_mean.to(inp.device)
866929

867930
prev_sum = prev_mean * prev_count
868931
new_count = prev_count + num_added

0 commit comments

Comments (0)