|
3 | 3 |
|
4 | 4 | import torch
|
5 | 5 | from compressed_tensors.quantization import disable_quantization
|
| 6 | +import torch.nn.functional as F |
6 | 7 | from compressed_tensors.utils import (
|
7 | 8 | align_modules,
|
8 | 9 | get_execution_device,
|
@@ -593,9 +594,9 @@ def _compute_best_scale(
|
593 | 594 | x_mean = x_mean.view(-1).to(device)
|
594 | 595 | w_mean = w_mean.view(-1).to(device)
|
595 | 596 |
|
596 |
| - for ratio in range(n_grid): |
| 597 | + for grid_idx in range(n_grid): |
597 | 598 | # create new scales
|
598 |
| - ratio = ratio / n_grid |
| 599 | + ratio = grid_idx / n_grid |
599 | 600 |
|
600 | 601 | # NOTE: s^-1 * x is fused here, according to paper
|
601 | 602 | if self.duo_scaling:
|
@@ -630,7 +631,7 @@ def _compute_best_scale(
|
630 | 631 | int_w_outputs = self._run_samples(parent_module)
|
631 | 632 |
|
632 | 633 | # compute mean squared error (L2 norm)
|
633 |
| - loss = _compute_loss(fp16_output, int_w_output) |
| 634 | + loss = F.mse_loss(int_w_output, fp16_output).item() |
634 | 635 |
|
635 | 636 | history.append(loss)
|
636 | 637 | if loss < best_error:
|
@@ -664,18 +665,6 @@ def _assert_all_activations_consumed(self):
|
664 | 665 | raise RuntimeError("Some cached activations were not used")
|
665 | 666 |
|
666 | 667 |
|
667 |
| -@torch.no_grad() |
668 |
| -@torch.compile() |
669 |
| -def _compute_loss( |
670 |
| - fp16_output: torch.Tensor, |
671 |
| - int_w_output: torch.Tensor, |
672 |
| -) -> torch.Tensor: |
673 |
| - """ |
674 |
| - Compute MSE loss over the flattened output of all batches |
675 |
| - """ |
676 |
| - return (fp16_output - int_w_output).view(-1).float().pow(2).mean() |
677 |
| - |
678 |
| - |
679 | 668 | @torch.compile()
|
680 | 669 | def _pseudo_quantize_tensor(
|
681 | 670 | w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
|
|
0 commit comments