Commit c1c27d0

switch to F.mse_loss()
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 8293f18 commit c1c27d0

File tree

1 file changed: src/llmcompressor/modifiers/awq/base.py (+4 additions, -15 deletions)


src/llmcompressor/modifiers/awq/base.py

Lines changed: 4 additions & 15 deletions
```diff
@@ -3,6 +3,7 @@
 
 import torch
 from compressed_tensors.quantization import disable_quantization
+import torch.nn.functional as F
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
```

```diff
@@ -593,9 +594,9 @@ def _compute_best_scale(
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)
 
-        for ratio in range(n_grid):
+        for grid_idx in range(n_grid):
             # create new scales
-            ratio = ratio / n_grid
+            ratio = grid_idx / n_grid
 
             # NOTE: s^-1 * x is fused here, according to paper
             if self.duo_scaling:
```

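For context, this loop is AWQ's grid search over candidate scale exponents: each `grid_idx` maps to a fractional `ratio` in [0, 1), and the rename stops the integer loop variable from being rebound to a float mid-iteration. A minimal sketch of the search pattern, assuming a duo-scaling formula along the lines of the AWQ paper (the exact expression and epsilons in this file may differ):

```python
import torch

def candidate_scales(x_mean: torch.Tensor, w_mean: torch.Tensor, n_grid: int = 20):
    """Yield one candidate scale vector per grid point, AWQ-style.

    x_mean and w_mean are per-channel mean absolute magnitudes of
    activations and weights, so both are nonnegative.
    """
    for grid_idx in range(n_grid):
        ratio = grid_idx / n_grid  # fraction in [0, 1)
        # duo-scaling: balance activation statistics against weight statistics
        scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
        # normalize overall magnitude so candidates are comparable
        scales = scales / (scales.max() * scales.min()).sqrt()
        yield ratio, scales
```
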
```diff
@@ -630,7 +631,7 @@ def _compute_best_scale(
             int_w_output = self._run_samples(parent_module)
 
             # compute mean squared error (L2 norm)
-            loss = _compute_loss(fp16_output, int_w_output)
+            loss = F.mse_loss(int_w_output, fp16_output).item()
 
             history.append(loss)
             if loss < best_error:
```

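Two details make the one-liner a drop-in replacement: `F.mse_loss` defaults to `reduction="mean"` (an elementwise mean over the whole tensor, matching the helper it replaces), and `.item()` turns the zero-dim result into a plain Python float for the `loss < best_error` bookkeeping. Argument order does not matter since the difference is squared; a quick check:

```python
import torch
import torch.nn.functional as F

a = torch.randn(4, 8)
b = torch.randn(4, 8)

# default reduction="mean" averages (input - target)^2 over every element,
# so mse_loss is symmetric in its two arguments
assert torch.allclose(F.mse_loss(a, b), F.mse_loss(b, a))

# .item() extracts a Python float from the zero-dim tensor
loss = F.mse_loss(a, b).item()
assert isinstance(loss, float)
```
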
```diff
@@ -664,18 +665,6 @@ def _assert_all_activations_consumed(self):
         raise RuntimeError("Some cached activations were not used")
 
 
-@torch.no_grad()
-@torch.compile()
-def _compute_loss(
-    fp16_output: torch.Tensor,
-    int_w_output: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Compute MSE loss over the flattened output of all batches
-    """
-    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
-
-
 @torch.compile()
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
```
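To see that the refactor is numerically faithful: the deleted helper flattened, upcast to float32, squared, and averaged, which is exactly mean-reduced MSE. The one nuance is dtype: the old `.float()` upcast is gone, and `F.mse_loss` computes in the inputs' dtype, so half-precision activations can round slightly differently. A sketch comparing the two (reimplementing the removed body for illustration):

```python
import torch
import torch.nn.functional as F

def old_compute_loss(fp16_output: torch.Tensor, int_w_output: torch.Tensor) -> torch.Tensor:
    # body of the removed _compute_loss helper
    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()

fp16_output = torch.randn(2, 16)
int_w_output = fp16_output + 0.01 * torch.randn_like(fp16_output)

# identical in float32
assert torch.allclose(old_compute_loss(fp16_output, int_w_output),
                      F.mse_loss(int_w_output, fp16_output))

# in float16 the new code stays in half precision, while the old helper
# upcast before squaring, so the two can differ in the last bits
h_old = old_compute_loss(fp16_output.half(), int_w_output.half())  # float32
h_new = F.mse_loss(int_w_output.half(), fp16_output.half())        # float16
print(h_old.dtype, h_new.dtype)  # torch.float32 torch.float16
```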
