@@ -470,8 +470,8 @@ def _apply_smoothing(self, model: Module) -> None:
 
             # [STEP 3]: Compute output of module
             # could cache from hook, rather than recomputing here
-            fp16_outputs = self._run_samples(parent_module)
-            if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs):
+            fp16_output = self._get_flattened_output(parent_module)
+            if fp16_output.numel() == 0:
                 logger.info(
                     f"Skipping smooth_layer {mapping.smooth_name}, no activations "
                     "found to scale. This can occasionally occur in MoE models "
@@ -484,7 +484,7 @@ def _apply_smoothing(self, model: Module) -> None:
 
             # [STEP 4]: Compute loss
             best_scales = self._compute_best_scale(
-                x_mean, w_mean, parent_module, balance_layers, fp16_outputs
+                x_mean, w_mean, parent_module, balance_layers, fp16_output
             )
 
     @torch.no_grad()
@@ -552,7 +552,7 @@ def _compute_best_scale(
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
         linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        fp16_output: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -616,7 +616,7 @@ def _compute_best_scale(
             int_w_outputs = self._run_samples(parent_module)
 
             # compute mean squared error (L2 norm)
-            loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
+            loss = _compute_loss(fp16_output, int_w_output)
 
             history.append(loss)
             if loss < best_error:
@@ -641,34 +641,6 @@ def _compute_best_scale(
 
         return best_scales.detach().cpu()
 
-    @torch.no_grad()
-    def _compute_loss(
-        self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
-        device: torch.device,
-    ) -> torch.Tensor:
-        loss = 0.0
-        num_elements = 0
-
-        # Compute the MSE loss for each batch
-        for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
-            batch_loss = (
-                (fp16_batch.to(device) - int_w_batch.to(device))
-                .view(-1)
-                .float()
-                .pow(2)
-                .sum()
-                .item()
-            )
-            loss += batch_loss
-            num_elements += fp16_batch.numel()
-
-        # Normalize the loss by the total number of elements
-        loss /= num_elements
-
-        return loss
-
     def _assert_all_activations_consumed(self):
         """
         Confirm all activations have been consumed
@@ -678,6 +650,17 @@ def _assert_all_activations_consumed(self):
             raise RuntimeError("Some cached activations were not used")
 
 
+@torch.no_grad()
+@torch.compile()
+def _compute_loss(
+    fp16_output: torch.Tensor,
+    int_w_output: torch.Tensor,
+) -> torch.Tensor:
+    """Compute MSE loss for each batch"""
+    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
+
+
+@torch.compile()
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
 ):
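Below is a minimal, self-contained sketch of the refactor this diff makes: the removed per-batch loop summed squared errors over a list of output tensors and normalized by the total element count, while the new path compares two pre-flattened tensors with a single `.pow(2).mean()` that `torch.compile` can compile into one fused kernel. The helper `flatten_outputs` is only a stand-in for what `_get_flattened_output` presumably does (concatenating cached per-batch outputs into one tensor); it is an assumption for illustration, not the repository's implementation.

```python
import torch
from typing import List


def flatten_outputs(outputs: List[torch.Tensor]) -> torch.Tensor:
    # Hypothetical stand-in for _get_flattened_output: concatenate per-batch
    # outputs into a single 1-D tensor (empty tensor if nothing was cached).
    if len(outputs) == 0:
        return torch.empty(0)
    return torch.cat([o.reshape(-1) for o in outputs])


@torch.no_grad()
def compute_loss(fp16_output: torch.Tensor, int_w_output: torch.Tensor) -> torch.Tensor:
    # Single-expression MSE over all elements; numerically equivalent to the
    # removed loop's sum of squared errors divided by the element count.
    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()


# Example: two cached "batches" of module outputs and their quantized counterparts
fp16_batches = [torch.randn(2, 8), torch.randn(3, 8)]
int_w_batches = [b + 0.01 * torch.randn_like(b) for b in fp16_batches]
loss = compute_loss(flatten_outputs(fp16_batches), flatten_outputs(int_w_batches))
print(float(loss))
```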