
Commit d80dd38

AWQ minor performance improvements to smoothing
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent d497b5a commit d80dd38

File tree

1 file changed: +16 -33 lines changed
  • src/llmcompressor/modifiers/awq

src/llmcompressor/modifiers/awq/base.py

Lines changed: 16 additions & 33 deletions
@@ -470,8 +470,8 @@ def _apply_smoothing(self, model: Module) -> None:
 
             # [STEP 3]: Compute output of module
             # could cache from hook, rather than recomputing here
-            fp16_outputs = self._run_samples(parent_module)
-            if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs):
+            fp16_output = self._get_flattened_output(parent_module)
+            if fp16_output.numel() == 0:
                 logger.info(
                     f"Skipping smooth_layer {mapping.smooth_name}, no activations "
                     "found to scale. This can occasionally occur in MoE models "
@@ -484,7 +484,7 @@ def _apply_smoothing(self, model: Module) -> None:
 
             # [STEP 4]: Compute loss
             best_scales = self._compute_best_scale(
-                x_mean, w_mean, parent_module, balance_layers, fp16_outputs
+                x_mean, w_mean, parent_module, balance_layers, fp16_output
             )
 
     @torch.no_grad()
@@ -552,7 +552,7 @@ def _compute_best_scale(
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
         linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        fp16_output: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -616,7 +616,7 @@ def _compute_best_scale(
             int_w_outputs = self._run_samples(parent_module)
 
             # compute mean squared error (L2 norm)
-            loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
+            loss = _compute_loss(fp16_output, int_w_output)
 
             history.append(loss)
             if loss < best_error:
@@ -641,34 +641,6 @@ def _compute_best_scale(
 
         return best_scales.detach().cpu()
 
-    @torch.no_grad()
-    def _compute_loss(
-        self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
-        device: torch.device,
-    ) -> torch.Tensor:
-        loss = 0.0
-        num_elements = 0
-
-        # Compute the MSE loss for each batch
-        for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
-            batch_loss = (
-                (fp16_batch.to(device) - int_w_batch.to(device))
-                .view(-1)
-                .float()
-                .pow(2)
-                .sum()
-                .item()
-            )
-            loss += batch_loss
-            num_elements += fp16_batch.numel()
-
-        # Normalize the loss by the total number of elements
-        loss /= num_elements
-
-        return loss
-
     def _assert_all_activations_consumed(self):
         """
         Confirm all activations have been consumed
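
For reference, the removed per-batch loop and the new single-tensor formulation compute the same value: summing squared errors over every batch and dividing by the total element count equals the mean squared error over the concatenation of all batches. A small self-contained check with toy tensors (illustrative only, not part of the commit):

import torch

torch.manual_seed(0)
fp16_outputs = [torch.randn(4, 8), torch.randn(2, 8)]                    # toy per-batch fp16 outputs
int_w_outputs = [t + 0.01 * torch.randn_like(t) for t in fp16_outputs]   # toy quantized outputs

# Old formulation: accumulate per-batch sum of squared errors, normalize by total elements
loss_old, num_elements = 0.0, 0
for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
    loss_old += (fp16_batch - int_w_batch).view(-1).float().pow(2).sum().item()
    num_elements += fp16_batch.numel()
loss_old /= num_elements

# New formulation: MSE over the flattened concatenation of all batches
fp16_output = torch.cat([t.reshape(-1) for t in fp16_outputs])
int_w_output = torch.cat([t.reshape(-1) for t in int_w_outputs])
loss_new = (fp16_output - int_w_output).view(-1).float().pow(2).mean()

assert abs(loss_old - loss_new.item()) < 1e-6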
@@ -678,6 +650,17 @@ def _assert_all_activations_consumed(self):
             raise RuntimeError("Some cached activations were not used")
 
 
+@torch.no_grad()
+@torch.compile()
+def _compute_loss(
+    fp16_output: torch.Tensor,
+    int_w_output: torch.Tensor,
+) -> torch.Tensor:
+    """Compute MSE loss for each batch"""
+    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
+
+
+@torch.compile()
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
 ):
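
Moving `_compute_loss` to module level and decorating it with `@torch.compile()` lets PyTorch 2.x trace the subtract/square/mean chain once and reuse the compiled kernel across the many scale-ratio candidates evaluated per layer; compilation happens lazily on the first call. A minimal usage sketch with toy tensors, mirroring the function added in the diff (not an official example from the repo):

import torch


@torch.no_grad()
@torch.compile()
def _compute_loss(fp16_output: torch.Tensor, int_w_output: torch.Tensor) -> torch.Tensor:
    """Compute MSE loss between the original and quantized outputs."""
    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()


fp16_output = torch.randn(1024)
int_w_output = fp16_output + 0.05 * torch.randn(1024)

# First call triggers compilation; later calls with the same shape/dtype reuse
# the cached compiled graph.
loss = _compute_loss(fp16_output, int_w_output)
print(loss.item())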
