@@ -160,6 +160,8 @@ class AWQModifier(Modifier, QuantizationMixin):
160160 _smooth_activation_means : dict [str , tuple [torch .FloatTensor , int ]] = PrivateAttr (
161161 default_factory = dict
162162 )
163+ # List to store error metrics for each layer
164+ _error_metrics : list [dict ] = PrivateAttr (default_factory = list )
163165
164166 def on_initialize (self , state : State , ** kwargs ) -> bool :
165167 """
@@ -273,9 +275,12 @@ def on_finalize(self, state: State, **kwargs) -> bool:
273275 if not self .ended_ :
274276 self .on_end (state , None )
275277
278+ self ._log_error_metrics ()
279+
276280 self ._parent_args_cache .clear ()
277281 self ._smooth_activation_means .clear ()
278282 self ._resolved_mappings .clear ()
283+ self ._error_metrics .clear ()
279284
280285 return True
281286
@@ -386,11 +391,11 @@ def cache_smooth_activations_hook(
386391 args : tuple [torch .Tensor , ...],
387392 _output : torch .Tensor ,
388393 ):
389- self ._smooth_activation_means [smooth_name ] = _accumulate_mean (
390- # Assume that first argument is the input
391- args [0 ].cpu ().abs ().detach ().flatten (0 , - 2 ),
394+ act_mean , count = _accumulate_mean (
395+ args [0 ].abs ().detach ().flatten (0 , - 2 ),
392396 self ._smooth_activation_means .get (smooth_name , None ),
393397 )
398+ self ._smooth_activation_means [smooth_name ] = (act_mean .cpu (), count )
394399
395400 return cache_smooth_activations_hook
396401
@@ -555,6 +560,7 @@ def _compute_best_scale(
555560 best_ratio = - 1
556561 best_scales = None
557562 best_error = float ("inf" )
563+ initial_error = None
558564
559565 org_sd = {
560566 k : v .cpu ()
@@ -600,11 +606,13 @@ def _compute_best_scale(
600606 ],
601607 ):
602608 total_iterations = n_grid * len (duo_scalings )
603- for grid_idx , use_duo_scaling in tqdm (
609+ pbar = tqdm (
604610 product (range (n_grid ), duo_scalings ),
605611 total = total_iterations ,
606- desc = "Grid search" ,
607- ):
612+ desc = f"Grid search for { mapping .smooth_name } " ,
613+ leave = False ,
614+ )
615+ for grid_idx , use_duo_scaling in pbar :
608616 # create new scales
609617 ratio = grid_idx / n_grid
610618
@@ -668,13 +676,17 @@ def _compute_best_scale(
668676 # compute mean squared error (L2 norm)
669677 loss = self ._compute_loss (fp16_outputs , int_w_outputs )
670678
679+ if initial_error is None :
680+ initial_error = loss
681+
671682 history .append (
672683 {"ratio" : ratio , "duo_scaling" : use_duo_scaling , "error" : loss }
673684 )
674685 if loss < best_error :
675686 best_error = loss
676687 best_ratio = ratio
677688 best_scales = scales .clone ()
689+ pbar .set_postfix ({"best_error" : f"{ best_error :.3e} " })
678690
679691 mapping .parent .load_state_dict (org_sd , strict = False )
680692
@@ -687,6 +699,25 @@ def _compute_best_scale(
687699 "https://github.com/vllm-project/llm-compressor/issues"
688700 )
689701
702+ err_reduction = best_error / initial_error if initial_error > 0 else 1.0
703+ logger .debug (
704+ f"AWQ grid search for { mapping .smooth_name } : "
705+ f"initial error = { initial_error :.3e} , "
706+ f"best error = { best_error :.3e} , "
707+ f"error reduction rate (best/initial) = { err_reduction * 100 :.3f} %"
708+ )
709+
710+ # Store error metrics for this layer
711+ self ._error_metrics .append (
712+ {
713+ "layer_name" : mapping .smooth_name ,
714+ "parent_name" : mapping .parent_name ,
715+ "initial_error" : initial_error ,
716+ "best_error" : best_error ,
717+ "reduction" : err_reduction ,
718+ }
719+ )
720+
690721 assert (
691722 torch .isnan (best_scales ).sum () == 0
692723 ), f"Nan found in scales: { best_scales } "
@@ -705,7 +736,7 @@ def _compute_loss(
705736 # Compute the MSE loss for each batch
706737 for fp16_batch , int_w_batch in zip (fp16_outputs , int_w_outputs ):
707738 loss += torch .nn .functional .mse_loss (
708- fp16_batch , int_w_batch .to (fp16_batch .device )
739+ fp16_batch , int_w_batch .to (fp16_batch .device ), reduction = "sum"
709740 ).item ()
710741 num_elements += fp16_batch .numel ()
711742
@@ -714,6 +745,37 @@ def _compute_loss(
714745
715746 return loss
716747
def _log_error_metrics(self):
    """
    Log the per-mapping grid-search error metrics and summary statistics.

    Emits debug-level records only; nothing is written to disk. The first
    record contains the ``duo_scaling`` and ``n_grid`` settings together
    with every entry accumulated in ``self._error_metrics``. The second
    record summarizes the ``"reduction"`` value of those entries
    (average, median, minimum, maximum).

    The summary is skipped when no metrics were recorded, so this method
    is safe to call unconditionally.
    """
    # Function-scope import: only this method needs it.
    from statistics import median

    metrics_data = {
        "quantization_config": {
            "duo_scaling": self.duo_scaling,
            "n_grid": self.n_grid,
        },
        "total_layers": len(self._error_metrics),
        "metrics": self._error_metrics,
    }
    logger.debug(f"AWQ per-mapping error metrics: {metrics_data}")

    reductions = [m["reduction"] for m in self._error_metrics]
    if not reductions:
        # With no entries the average would divide by zero and
        # min()/max() would raise ValueError, so there is nothing to
        # summarize.
        return

    avg_reduction = sum(reductions) / len(reductions)
    # statistics.median averages the two middle values for even-length
    # input, unlike indexing the sorted list at len // 2.
    median_reduction = median(reductions)
    min_reduction = min(reductions)
    max_reduction = max(reductions)
    logger.debug(
        f"Error reduction statistics: "
        f"avg={avg_reduction:.4f}, median={median_reduction:.4f}, "
        f"min={min_reduction:.4f}, max={max_reduction:.4f}"
    )
778+
717779 def _assert_all_activations_consumed (self ):
718780 """
719781 Confirm all activations have been consumed
@@ -860,9 +922,10 @@ def _accumulate_mean(
860922 sum_added = inp .sum (dim = 0 )
861923 num_added = inp .size (0 )
862924 if prev_mean_and_count is None :
863- return sum_added , num_added
925+ return sum_added / num_added , num_added
864926
865927 prev_mean , prev_count = prev_mean_and_count
928+ prev_mean = prev_mean .to (inp .device )
866929
867930 prev_sum = prev_mean * prev_count
868931 new_count = prev_count + num_added
0 commit comments