Skip to content

Commit 6f951b7

Browse files
codeassist updates
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent d80dd38 commit 6f951b7

File tree

1 file changed

+3
-1
lines changed
  • src/llmcompressor/modifiers/awq

1 file changed

+3
-1
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,9 @@ def _compute_loss(
656656
fp16_output: torch.Tensor,
657657
int_w_output: torch.Tensor,
658658
) -> torch.Tensor:
659-
"""Compute MSE loss for each batch"""
659+
"""
660+
Compute MSE loss over the flattened output of all batches
661+
"""
660662
return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
661663

662664

0 commit comments

Comments
 (0)