
Commit d23090e

AWQ minor performance improvements to smoothing
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent e330182 commit d23090e

File tree

1 file changed: +16 -33 lines changed

  • src/llmcompressor/modifiers/awq/base.py


src/llmcompressor/modifiers/awq/base.py

Lines changed: 16 additions & 33 deletions
@@ -470,8 +470,8 @@ def _apply_smoothing(self, model: Module) -> None:
 
             # [STEP 3]: Compute output of module
             # could cache from hook, rather than recomputing here
-            fp16_outputs = self._run_samples(parent_module)
-            if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs):
+            fp16_output = self._get_flattened_output(parent_module)
+            if fp16_output.numel() == 0:
                 logger.info(
                     f"Skipping smooth_layer {mapping.smooth_name}, no activations "
                     "found to scale. This can occasionally occur in MoE models "
@@ -498,7 +498,7 @@ def _apply_smoothing(self, model: Module) -> None:
 
             # [STEP 4]: Compute loss
             best_scales = self._compute_best_scale(
-                x_mean, w_mean, parent_module, balance_layers, fp16_outputs
+                x_mean, w_mean, parent_module, balance_layers, fp16_output
             )
 
     @torch.no_grad()
@@ -566,7 +566,7 @@ def _compute_best_scale(
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
         linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        fp16_output: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -630,7 +630,7 @@ def _compute_best_scale(
             int_w_outputs = self._run_samples(parent_module)
 
             # compute mean squared error (L2 norm)
-            loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
+            loss = _compute_loss(fp16_output, int_w_output)
 
             history.append(loss)
             if loss < best_error:
@@ -655,34 +655,6 @@ def _compute_best_scale(
 
         return best_scales.detach().cpu()
 
-    @torch.no_grad()
-    def _compute_loss(
-        self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
-        device: torch.device,
-    ) -> torch.Tensor:
-        loss = 0.0
-        num_elements = 0
-
-        # Compute the MSE loss for each batch
-        for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
-            batch_loss = (
-                (fp16_batch.to(device) - int_w_batch.to(device))
-                .view(-1)
-                .float()
-                .pow(2)
-                .sum()
-                .item()
-            )
-            loss += batch_loss
-            num_elements += fp16_batch.numel()
-
-        # Normalize the loss by the total number of elements
-        loss /= num_elements
-
-        return loss
-
     def _assert_all_activations_consumed(self):
         """
         Confirm all activations have been consumed
@@ -692,6 +664,17 @@ def _assert_all_activations_consumed(self):
             raise RuntimeError("Some cached activations were not used")
 
 
+@torch.no_grad()
+@torch.compile()
+def _compute_loss(
+    fp16_output: torch.Tensor,
+    int_w_output: torch.Tensor,
+) -> torch.Tensor:
+    """Compute MSE loss for each batch"""
+    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
+
+
+@torch.compile()
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
 ):
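
The core of the change is the loss computation: the removed _compute_loss method accumulated a squared-error sum batch by batch, calling .item() on every batch (a GPU-to-CPU sync), while the new module-level _compute_loss takes already-flattened outputs and computes the mean squared error in one vectorized expression under @torch.compile(). A minimal standalone sketch of the two formulations (illustrative names and shapes, not code from this commit), showing they produce the same value when the flattened tensors are the concatenation of the per-batch outputs:

import torch

# Old formulation (mirrors the removed method): per-batch loop with a
# Python-side accumulator and one .item() sync per batch.
def loss_per_batch(fp16_outputs, int_w_outputs):
    loss, num_elements = 0.0, 0
    for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
        loss += (fp16_batch - int_w_batch).view(-1).float().pow(2).sum().item()
        num_elements += fp16_batch.numel()
    return loss / num_elements

# New formulation (mirrors the added function): one vectorized mean over
# pre-flattened outputs, amenable to kernel fusion via torch.compile.
def loss_flattened(fp16_output, int_w_output):
    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()

# Illustrative check with arbitrary shapes: both agree because a mean over the
# concatenation equals the summed squared error divided by the total element count.
fp16_batches = [torch.randn(2, 8) for _ in range(3)]
intw_batches = [b + 0.01 * torch.randn_like(b) for b in fp16_batches]

old_loss = loss_per_batch(fp16_batches, intw_batches)
new_loss = loss_flattened(
    torch.cat([b.reshape(-1) for b in fp16_batches]),
    torch.cat([b.reshape(-1) for b in intw_batches]),
).item()
assert abs(old_loss - new_loss) < 1e-6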

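Both the new _compute_loss and the existing _pseudo_quantize_tensor gain a @torch.compile() decorator, so their element-wise arithmetic can be traced and fused into a compiled kernel on first call rather than dispatched op by op in eager mode. A minimal sketch of that decorator pattern on a standalone function (illustrative only; actual speedups depend on hardware, input shapes, and PyTorch version):

import torch

@torch.compile()  # compiles on first call; subsequent calls reuse the compiled graph
def fused_mse(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return (a - b).view(-1).float().pow(2).mean()

x = torch.randn(4096)
y = x + 0.1 * torch.randn(4096)
print(fused_mse(x, y))  # first call triggers compilation, then runs the fused kernel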